This is an abbreviated version of the c1_onboarding.ipynb notebook that specifically focuses on artifacts

## 🪄 Install `wandb` library and login


The first step on our journey is to install the client, which is as easy as:



In [None]:
!pip install c1_aiml_aem -qU


In [None]:
!pip install "numpy<1.26.4"



In [None]:
!pip install scikit-learn



## Log in to W&B
- You can explicitly login using `wandb login` or `wandb.login()` (See below)
- Alternatively you can set environment variables. There are several env variables which you can set to change the behavior of W&B logging. The most important are:
    - `WANDB_API_KEY` - find this in your "Settings" section under your profile
    - `WANDB_BASE_URL` - this is the url of the W&B server
- Find your API Token in "Profile" -> "Settings" in the W&B App


In [None]:
import numpy as np
import random
from c1_aiml_aem import wandb

In [None]:
## Replace this with Cap1 Instance url
WANDB_HOST = "https://wandb.cloud.capitalone.com" #@param
# Equivalent to running "wandb login" in your shell

wandb.login(host= WANDB_HOST)

#
# Note that https://api.wandb.ai is the default and points to the publicly hosted
# app.
#
# Alternatively, you can configure this with environment variables:
# export WANDB_API_KEY="<your-api-key>"
# export WANDB_BASE_URL="<your-wandb-endpoint>"

Calling `wandb login` or `wandb.login` will write your API key to your `~/.netrc` file. __To authenticate the client in a headless job on the cloud, you will definitely want to use the `WANDB_API_KEY` environment variable__.
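
For a headless job it can be useful to fail early when the credentials are missing. The following is a minimal sketch using only the standard library; the helper name `missing_wandb_env` is hypothetical, and treating `WANDB_BASE_URL` as required is an assumption that fits a privately hosted instance like the one used in this notebook.

```python
import os

# Variables this notebook's setup relies on for a private W&B instance.
# Treating both as required is an assumption made for this sketch.
REQUIRED_VARS = ("WANDB_API_KEY", "WANDB_BASE_URL")


def missing_wandb_env(environ=os.environ):
    """Return the names of required variables that are unset or empty."""
    return [name for name in REQUIRED_VARS if not environ.get(name)]


missing = missing_wandb_env()
if missing:
    print(f"Set these before launching the job: {', '.join(missing)}")
else:
    print("W&B environment variables are set.")
```

Running a check like this at the top of a job script surfaces a configuration problem before any training work starts.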

**Default Destination:** When a user signs up to the instance and joins a team, wandb automatically writes runs to this team. This default is controlled through your user settings and can be updated by:

*   Visiting https://<host-url>/settings
*   Finding the `Default Team` section
*   Updating `Default location to create new projects` to the entity of your choice

In [None]:
import random
import math

WANDB_ENTITY = 'wb-new-user-training-20251008' #@param #Point to a team you are a member of!
WANDB_PROJECT = "workshop_wandb_intro" #@param
YOUR_NAME = "uma" #@param #We will use this for our filtering and grouping to make it easy for you to identify your runs in the project

### Log Dataframes of Media

You can also log `pandas.DataFrame` objects with `.log`! These will be converted into a `wandb.Table` and displayed interactively inside W&B.

Note: One of the most powerful features of `wandb.Table`s is that you can include any `wandb` type as a cell value! This includes images, plots, videos, audio... almost anything 🤩

Below we will use the Oxford-IIIT Pet Dataset, which covers 37 pet breeds and provides segmentation masks in its annotations, as our media logging example.

In [None]:
!curl -SL -qq https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz > images.tar.gz
!curl -SL -qq https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz > annotations.tar.gz
!tar -xzf images.tar.gz
!tar -xzf annotations.tar.gz

In [None]:
import os
import wandb
import numpy as np
from PIL import Image
from pathlib import Path

In [None]:
#Utility functions

# Function to load an image and mask
def load_image_and_mask(image_path, mask_path):
    image = np.array(Image.open(image_path))
    mask = np.array(Image.open(mask_path))
    return image, mask

# Function to create W&B mask overlay
def wb_mask(image, mask):
    return wandb.Image(image, masks={"predictions": {"mask_data": mask}}, caption="Segmentation Image")

def log_single_images(path_img, path_lbl):
  # Single Image and Mask Logging
  image_path = path_img / 'Abyssinian_1.jpg'
  mask_path = path_lbl / 'Abyssinian_1.png'
  image_np, mask_np = load_image_and_mask(image_path, mask_path)
  return image_np, mask_np

def wandb_table_multiple_imags(path_img,path_lbl, num_images):
    # Build a table with one row per image
    table = wandb.Table(columns=["ID", "Original Image", "Image with Mask"])

    # Log first X images and their masks to the table
    for each in os.listdir(path_img)[:num_images]:  # limiting to first X images
        image_path = path_img / each
        mask_path = path_lbl / f'{Path(each).stem}.png'  # Adjust to match mask filenames
        image_np, mask_np = load_image_and_mask(image_path, mask_path)

        # Create mask overlay using W&B
        mask_img = wb_mask(image_np, mask_np)

        # Add image path, original image, and image with mask to the table
        table.add_data(str(image_path), wandb.Image(image_np), mask_img)

    return table

Logging a single image and a table of images with segmentation masks

In [None]:
run = wandb.init(
    entity = WANDB_ENTITY,
    project=WANDB_PROJECT,
    group=YOUR_NAME,
    name="logging_rich_media",
    )

# Define paths
path_img = Path('images')
path_lbl = Path('annotations/trimaps')

image_np, mask_np = log_single_images(path_img, path_lbl)

# Log single image and segmentation mask
run.log({
    "input_image": wandb.Image(image_np, caption="Input Image"),
    "segmentation_mask": wb_mask(image_np, mask_np)
})

#Log tables of images and segmentation mask
image_tables = wandb_table_multiple_imags(path_img, path_lbl, 30)

# Log the table
run.log({"Segmentation Table": image_tables})

run.finish()

### Log Sequences of Media

If you periodically call `run.log` to log a number (for example, loss), Weights & Biases will automatically render a line plot showing the change in that value over time (a loss curve). You can also log media under a key more than once over the course of an experiment, in which case Weights & Biases will display that media with a step slider so you can scrub over the course of the experiment and see how it changed. This is particularly useful for seeing how model predictions and visualizations of model performance (e.g. a precision/recall curve) change over time. In the example below, we log a `wandb.Image` repeatedly, blurring it a little more at each step, just to demonstrate how this works.

In [None]:
%%sh
curl https://parade.com/.image/t_share/MTkwNTgwOTUyNjU2Mzg5MjQ1/albert-einstein-quotes-jpg.jpg > image.jpg

In [None]:
from PIL import Image, ImageFilter
import pandas as pd
# Load image with pillow, resize to 512 square
im = Image.open("./image.jpg").resize((512, 512))
images = []
with wandb.init(project = WANDB_PROJECT) as run:

  for step in range(10):

    # Log image
    images.append( (step, wandb.Image(im)))
    run.log({"image": wandb.Image(im)})

    # Apply small Gaussian blur
    im = im.filter(ImageFilter.GaussianBlur(radius=1.5))

  # Also log the images + associated logging step to a W&B Table
  run.log({ "images_df": pd.DataFrame( images, columns = ["step", "images"])})

# Artifacts

Use W&B Artifacts to track and version data as the inputs and outputs of your W&B Runs. In addition to logging hyperparameters, metadata, and metrics to a run, you can use an artifact to log the dataset used to train the model as input and the resulting model checkpoints as outputs.




For this demo, we are going to go through the workflow of
1. Creating a dataset
2. Logging it to wandb
3. Processing that dataset
4. Logging the processed data to wandb
5. Conducting model training
6. Viewing the entire lineage of this process in the wandb UI

## Create a Dataset
Let's create some datasets that we can work with in this example.

In [None]:
import os
import numpy as np
import csv

directory = "dataset"
os.makedirs(directory, exist_ok=True)
file1, file2 = os.path.join(directory, "file1.csv"), os.path.join(directory, "file2.csv")

def generate_dummy_data(num_samples):
    data = [
        np.random.normal(50, 10, num_samples),
        np.random.randint(1, 100, num_samples),
        np.random.choice(['A', 'B', 'C', 'D'], num_samples),
        np.random.uniform(0.0, 1.0, num_samples)
    ]
    return zip(*data)

def save_to_csv(file, data):
    with open(file, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['feature1', 'feature2', 'feature3', 'feature4'])
        writer.writerows(data)

num_samples = 100
save_to_csv(file1, generate_dummy_data(num_samples))
save_to_csv(file2, generate_dummy_data(num_samples))

The general workflow for creating an Artifact is:

1. Initialize a run.
2. Create an Artifact.
3. Add any files, directories, or pointers to the new Artifact that you want to track and version.
4. Log the artifact in the W&B platform.

See the [Artifacts Reference guide](https://docs.wandb.ai/ref/python/artifact) for more information and other commonly used arguments, including how to store additional metadata.

Each time `log_artifact` is executed, wandb will create a new version of the Artifact within Weights & Biases if the underlying data has changed.
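
To build intuition for "a new version only if the data has changed", here is a toy model of content-based versioning using only the standard library. It is a sketch of the idea, not a description of how W&B implements it: `directory_digest` and `ToyVersionStore` are hypothetical names, and the real service tracks much more than a single hash.

```python
import hashlib
import os
import tempfile


def directory_digest(directory):
    """Hash file names and contents in a stable (sorted) order."""
    digest = hashlib.sha256()
    for name in sorted(os.listdir(directory)):
        path = os.path.join(directory, name)
        if os.path.isfile(path):
            digest.update(name.encode())
            with open(path, "rb") as f:
                digest.update(f.read())
    return digest.hexdigest()


class ToyVersionStore:
    """Hypothetical in-memory stand-in for an artifact's version history."""

    def __init__(self):
        self.digests = []

    def log(self, directory):
        """Record a new version only when the content differs from the last one."""
        digest = directory_digest(directory)
        if not self.digests or self.digests[-1] != digest:
            self.digests.append(digest)
        return f"v{len(self.digests) - 1}"


with tempfile.TemporaryDirectory() as tmp:
    csv_path = os.path.join(tmp, "file1.csv")
    with open(csv_path, "w") as f:
        f.write("feature1\n1\n")

    store = ToyVersionStore()
    first = store.log(tmp)   # new content: first version
    second = store.log(tmp)  # unchanged content: no new version
    with open(csv_path, "a") as f:
        f.write("2\n")
    third = store.log(tmp)   # changed content: new version

print(first, second, third)
```

Logging the same directory twice returns the same version label, while appending a row produces a new one.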

## Logging this dataset artifact

In [None]:
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type='log_dataset')

artifact = wandb.Artifact(f"my_first_artifact_{YOUR_NAME}", type="dataset")
# the below will add two individual files to the artifact.
artifact.add_file(local_path=f"{directory}/file1.csv")
artifact.add_file(local_path=f"{directory}/file2.csv")

# or the below if you wanted to add the entire directory contents.
artifact.add_dir(local_path=f"{directory}")
# explicitly log the artifact to Weights & Biases.
run.log_artifact(artifact)

run.finish()

## Processing and Consuming the dataset Artifact

When you want to use a specific version of an Artifact in a downstream task, you can specify it either by version index (`v0`, `v1`, `v2`, and so on) or by an alias you have added. The `latest` alias always refers to the most recent version of the Artifact logged.

The following code snippet specifies that the W&B Run will use the artifact called `my_first_artifact_{YOUR_NAME}` with the alias `latest`. We then preprocess the dataset and log the result back to wandb, so we can see the lineage up to this point:
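
The `name:version` strings passed to `use_artifact` in this notebook can be assembled with a small helper. This is a sketch with hypothetical function names (`artifact_ref`, `is_version_index`); it only formats and inspects strings and does not contact W&B.

```python
import re


def artifact_ref(name, version_or_alias="latest"):
    """Build a 'name:version_or_alias' string like those used in this notebook."""
    return f"{name}:{version_or_alias}"


def is_version_index(version_or_alias):
    """True for version indices such as 'v0' or 'v12', False for aliases."""
    return re.fullmatch(r"v\d+", version_or_alias) is not None


YOUR_NAME = "uma"  # same placeholder value as earlier in the notebook
print(artifact_ref(f"my_first_artifact_{YOUR_NAME}"))        # my_first_artifact_uma:latest
print(artifact_ref(f"my_first_artifact_{YOUR_NAME}", "v0"))  # my_first_artifact_uma:v0
print(is_version_index("v0"), is_version_index("latest"))    # True False
```

Pinning a version index such as `v0` keeps a downstream run reproducible, whereas `latest` moves whenever a new version is logged.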

In [None]:
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT, job_type='process_dataset')
artifact = run.use_artifact(artifact_or_name=f"my_first_artifact_{YOUR_NAME}:latest") # this creates a reference within Weights & Biases that this artifact was used by this run.
path = artifact.download() # this downloads the artifact from Weights & Biases to your local system where the code is executing.

print(f"Data directory located at {path}")

In [None]:
processed_directory = "processed_dataset"
os.makedirs(processed_directory, exist_ok=True)
file1, file2 = os.path.join(processed_directory, "file1_processed.csv"), os.path.join(processed_directory, "file2_processed.csv")

In [None]:
# Step 2: Function to process and save the modified CSV data
def process(input_file_path, output_file_path):
    modified_data = []
    with open(input_file_path, 'r') as f:
        reader = csv.reader(f)
        headers = next(reader)  # Skip headers
        for row in reader:
            # Example modification: Adjust feature1 by adding a constant
            row[0] = str(float(row[0]) + 10)  # Modify feature1
            modified_data.append(row)

    # Save the modified data to the output path
    with open(output_file_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(headers)  # Write the headers back
        writer.writerows(modified_data)

# Apply modification to the CSV files
process(os.path.join(path, "file1.csv"), os.path.join(processed_directory, "file1_processed.csv"))
process(os.path.join(path, "file2.csv"), os.path.join(processed_directory, "file2_processed.csv"))

In [None]:
# Step 4: Create a new artifact to store the modified data
processed_artifact = wandb.Artifact(
    f"my_processed_artifact_{YOUR_NAME}",
    type="processed_dataset"
)

# Add the modified CSV files to the new artifact
processed_artifact.add_file(local_path=f"{processed_directory}/file1_processed.csv")
processed_artifact.add_file(local_path=f"{processed_directory}/file2_processed.csv")

# or the below if you wanted to add the entire directory contents.
processed_artifact.add_dir(local_path=f"{processed_directory}")

# Step 5: Log the processed artifact
run.log_artifact(processed_artifact)

# Finish the run
run.finish()

## Using our dataset during model training and logging model checkpoint!

In [None]:
import torch
import torch.nn as nn

run = wandb.init(
    entity = WANDB_ENTITY,
    project=WANDB_PROJECT,
    group=YOUR_NAME,
    job_type = "training",
    config={'param': 42}
)

# Use our processed dataset in our training run
# this creates a reference within Weights & Biases that this artifact was used by this run
artifact = run.use_artifact(artifact_or_name=f"my_processed_artifact_{YOUR_NAME}:latest")

#Save simple neural network model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the model
model = SimpleModel()

# doing some dummy logging here
for i in range(5):
  run.log({"acc": random.random()})

# Save the model and log to artifacts
model_path = "simple_model.pth"
torch.save(model.state_dict(), model_path)

# Log the model as an artifact
art = wandb.Artifact(name=f"simple-model-{YOUR_NAME}", type="model")
art.add_file(model_path)
run.log_artifact(art)

run.finish()

For more information on ways to customize your Artifact download, including via the command line, see the [Download and Usage guide](https://docs.wandb.ai/guides/artifacts/download-and-use-an-artifact).

## How can we see which experiments used a particular dataset?


If we want to see which wandb runs consumed a specific dataset, we can do this in two ways:
1. Viewing the lineage in the UI
2. Programmatically accessing which run names and ids consumed a specific dataset (below)

In [None]:
run = wandb.init(
    entity = WANDB_ENTITY,
    project=WANDB_PROJECT,
)

artifact = run.use_artifact(artifact_or_name=f"my_processed_artifact_{YOUR_NAME}:latest")

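# Use a separate loop variable so the active `run` is not overwritten
for consumer_run in artifact.used_by():
  print(consumer_run.name, consumer_run.id)

run.finish()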

## Update Artifact version metadata
You can update the description, metadata, and alias of an artifact on the W&B platform during or outside a W&B Run.

This example changes the description and metadata of the `my_first_artifact_{YOUR_NAME}` artifact inside a run:

In [None]:
run = wandb.init(entity=WANDB_ENTITY, project=WANDB_PROJECT)
artifact = run.use_artifact(artifact_or_name=f"my_first_artifact_{YOUR_NAME}:latest")
artifact.description = "This is an edited description."
artifact.metadata = {"source": "local disk", "internal data owner": "platform team"}
artifact.save()  # persists changes to an Artifact's properties
run.finish()

## Code Snippet for Logging Reference Artifacts

Artifacts currently support the following URI schemes:

* **http(s)://:** A path to a file accessible over HTTP. The artifact will track checksums in the form of etags and size metadata if the HTTP server supports the ETag and Content-Length response headers.
* **s3://:** A path to an object or object prefix in S3. The artifact will track checksums and versioning information (if the bucket has object versioning enabled) for the referenced objects. Object prefixes are expanded to include the objects under the prefix, default up to 100,000 objects.
* **gs://:** A path to an object or object prefix in GCS. The artifact will track checksums and versioning information (if the bucket has object versioning enabled) for the referenced objects. Object prefixes are expanded to include the objects under the prefix, default up to 100,000 objects.

The example below references local files using a `file://` URI.
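
Before calling `add_reference`, it can help to check that a URI uses one of the schemes mentioned in this section. The sketch below uses only the standard library; `reference_scheme` and `ALLOWED_SCHEMES` are hypothetical names, and `file` is included only because the cell in this notebook uses a `file://` URI.

```python
from urllib.parse import urlparse

# Schemes named in this section, plus "file" as used by the notebook's example cell.
ALLOWED_SCHEMES = {"http", "https", "s3", "gs", "file"}


def reference_scheme(uri):
    """Return the URI scheme, or raise if it is not one covered in this section."""
    scheme = urlparse(uri).scheme
    if scheme not in ALLOWED_SCHEMES:
        raise ValueError(f"Unsupported or missing scheme in {uri!r}")
    return scheme


print(reference_scheme("file:///content/sample_data"))  # file
print(reference_scheme("s3://my-bucket/some/prefix"))   # s3
```

A bare path such as `/content/sample_data` has no scheme and is rejected by this check, which is a common mistake when referencing local files.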

In [None]:
run = wandb.init(entity=WANDB_ENTITY,
                project=WANDB_PROJECT,
                job_type="upload-references")
artifact = wandb.Artifact(name=f"local-file-references_{YOUR_NAME}", type="reference-dataset")
artifact.add_reference("file:///content/sample_data", checksum=True)
run.log_artifact(artifact)
run.finish()

## **Artifacts Time-to-live (TTL)**

W&B Artifacts supports setting time-to-live (TTL) policies on each version of an Artifact. The following examples show the use of a TTL policy in a common Artifact logging workflow. We'll cover:

* Setting a TTL policy when creating an Artifact
* Retroactively setting a TTL for specific Artifact aliases

## Setting TTL on New Artifacts
Below we create two new Artifacts from the Colab-provided `sample_data` files:
- mnist_test.csv
- mnist_train_small.csv

Each file is added to an artifact of type `mnist_dataset`, which is then assigned a TTL.

In [None]:
from datetime import timedelta

run = wandb.init(entity=WANDB_ENTITY,
                project=WANDB_PROJECT,
                job_type="raw-data")

raw_mnist_train = wandb.Artifact(
    f"mnist_train_small_{YOUR_NAME}",
    type="mnist_dataset",
    description="Small MNIST Training Set"
)

raw_mnist_train.add_file("sample_data/mnist_train_small.csv")
raw_mnist_train.ttl = timedelta(days=10)
run.log_artifact(raw_mnist_train, aliases=["small", "mnist", "train"])

raw_mnist_test = wandb.Artifact(
    f"mnist_test_small_{YOUR_NAME}",
    type="mnist_dataset",
    description="Small MNIST Test Set"
)

raw_mnist_test.add_file("sample_data/mnist_test.csv")
raw_mnist_test.ttl = timedelta(days=10)
run.log_artifact(raw_mnist_test, aliases=["small", "mnist", "test"])

run.finish()
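
The retroactive case mentioned at the start of this section can be sketched as a small helper. It assumes an already-logged artifact, fetched for example with `run.use_artifact(...)` as in earlier cells, accepts a `ttl` assignment followed by `save()`, mirroring the `ttl` and `save()` calls shown in this notebook. The helper name `set_ttl` and the `FakeArtifact` stand-in are hypothetical and exist only so the sketch runs without a W&B connection.

```python
from datetime import timedelta


def set_ttl(artifact, days):
    """Assign a TTL to an artifact-like object and persist the change."""
    if days <= 0:
        raise ValueError("days must be positive")
    artifact.ttl = timedelta(days=days)
    artifact.save()  # persist the updated property
    return artifact.ttl


class FakeArtifact:
    """Stand-in used only to exercise the helper offline."""

    def __init__(self):
        self.ttl = None
        self.saved = False

    def save(self):
        self.saved = True


demo = FakeArtifact()
print(set_ttl(demo, days=30))  # 30 days, 0:00:00
```

With a real artifact, the object passed in would be the one returned for a specific alias, for example the `train` alias assigned above.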

# Full pipeline - data upload, model training, linking the best model to registry

The below pipeline includes:

* Data Versioning: The Heart Disease dataset is split into training, validation, and test sets, each logged as a W&B artifact for easy tracking and reproducibility. The training dataset is then linked to our dataset registry.

* Model Training: A neural network is trained on the training set, with performance monitored on the validation set. The best model version is saved and versioned as a W&B artifact, which is then linked to our model registry.

* What It Showcases:
    * How to use W&B for seamless data and model versioning.
    * Best practices for tracking model performance during training and evaluation.
    * Efficient and reproducible ML workflows in a production-ready environment.

In [None]:
COLLECTION_NAME = "heart-disease" #@param {type: "string"}

In [None]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from io import StringIO

# Load the Heart Disease dataset
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data"
columns = ["age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
           "thalach", "exang", "oldpeak", "slope", "ca", "thal", "target"]

In [None]:
# Data download and processing
raw_data = !curl https://archive.ics.uci.edu/ml/machine-learning-databases/heart-disease/processed.cleveland.data
clean_rows = [
    row
    for row in raw_data
    if row.strip()                  # drop the empty string
       and "--:--:--" not in row    # progress-bar lines
       and "%" not in row           # header / percent lines
       and "Dload" not in row       # second header line
]

# Use a name that does not shadow the `csv` module imported earlier
csv_text = '\n'.join(clean_rows)
data = pd.read_csv(StringIO(csv_text), header=None, names=columns)

# Replace missing values ('?') with NaN and drop rows with NaN values
data.replace('?', np.nan, inplace=True)
data = data.dropna().astype(float)

# Convert target variable: 0 = no heart disease, 1 = presence of heart disease
data['target'] = (data['target'] > 0).astype(int)

# Shuffle the dataset to ensure random distribution
data = data.sample(frac=1, random_state=42).reset_index(drop=True)

# Perform a train/validation/test split (60/20/20)
train_size = int(0.6 * len(data))
val_size = int(0.2 * len(data))
test_size = len(data) - train_size - val_size

train_data = data[:train_size]
val_data = data[train_size:train_size + val_size]
test_data = data[train_size + val_size:]

# Save the entire dataset as a CSV file
data.to_csv("heart_disease_full_dataset.csv", index=False)

The following function is used to save a dataset to a file and log it as an artifact

In [None]:
# Simple function to save and log dataset artifacts
def save_and_log_dataset(data, filename, artifact_name, aliases):
    # Save the dataset as a CSV file
    data.to_csv(filename, index=False)

    # Create and log the dataset artifact
    dataset_artifact = wandb.Artifact(name=artifact_name, type='dataset')
    dataset_artifact.add_file(filename)
    wandb.log_artifact(dataset_artifact, aliases=aliases)

    return dataset_artifact

The following cell is going to save our train, validation, and test datasets as artifacts, as well as link our training dataset to our Dataset registry

In [None]:
#Upload data to wandb

run = wandb.init(entity=WANDB_ENTITY,
                project=WANDB_PROJECT,
                group = YOUR_NAME,
                job_type="heart-disease-data-uploads",
                name = f"heart_disease_data_uploads_{YOUR_NAME}",
                tags = ["data-upload"]
                )

# Save and log the entire dataset
full_artifact = save_and_log_dataset(data, "heart_disease_full_dataset.csv", f'heart_disease_full_dataset_{YOUR_NAME}', ["initial_commit", "complete_dataset"])

# Save and log the training dataset
train_artifact = save_and_log_dataset(train_data, "heart_disease_train_dataset.csv", f'heart_disease_train_dataset_{YOUR_NAME}', ["initial_commit", "train_split"])

# Save and log the validation dataset
val_artifact = save_and_log_dataset(val_data, "heart_disease_val_dataset.csv", f'heart_disease_validation_dataset_{YOUR_NAME}', ["initial_commit", "validation_split"])

# Save and log the test dataset
test_artifact = save_and_log_dataset(test_data, "heart_disease_test_dataset.csv", f'heart_disease_test_dataset_{YOUR_NAME}', ["initial_commit", "test_split"])


#Log all dataset to W&B tables for visual analysis
run.log({f"train_data_table_{YOUR_NAME}": wandb.Table(dataframe=train_data),
           f"test_data_table_{YOUR_NAME}": wandb.Table(dataframe=test_data),
           f"validation_data_table_{YOUR_NAME}": wandb.Table(dataframe=val_data)})

# Linking Training Dataset to collection in dataset registry
target_path = f"WandB-Intro-Workshop-Registry-dataset/{COLLECTION_NAME}"

# this line will only work if you are logging to a team entity - which has access to Registry
run.link_artifact(
  artifact=train_artifact,
  target_path= target_path
)

run.finish()

The following cell is going to pull the latest training dataset from our training dataset registry so we can start training on it.

During training, we save our model checkpoints as artifacts, but we're going to promote our best model from our training to our model registry
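
The "keep only improving checkpoints" rule used in the training loop can be isolated as plain Python. This sketch uses made-up validation losses and a hypothetical helper name, `best_checkpoints`; it shows which epochs would trigger a checkpoint save and does not train or log anything.

```python
def best_checkpoints(val_losses):
    """Return (epoch, loss, version) for each epoch that beats the best loss so far."""
    best = float("inf")
    version = 1
    saved = []
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            saved.append((epoch, loss, version))
            version += 1
    return saved


history = [0.70, 0.65, 0.68, 0.60, 0.61]  # made-up validation losses
for epoch, loss, version in best_checkpoints(history):
    print(f"epoch {epoch}: val_loss={loss:.2f} -> checkpoint v{version}")
```

Because a checkpoint is written only on improvement, the last one saved is also the best one, which is why the final artifact is the one promoted to the registry.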

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [None]:

#simple function for loading artifacts from wandb
def load_and_split_data(entity, project, artifact_name, your_name, split_name):

    artifact_full_name = f'{entity}/{project}/{artifact_name}_{your_name}:latest'
    artifact = wandb.use_artifact(artifact_full_name, type='dataset')
    # download() fetches the files. If they are already available locally,
    # calling wandb.use_artifact() alone is enough to record the lineage.
    artifact_dir = artifact.download()

    data = pd.read_csv(f"{artifact_dir}/heart_disease_{split_name}_dataset.csv")

    X = torch.tensor(data.drop("target", axis=1).values, dtype=torch.float32)
    y = torch.tensor(data["target"].values, dtype=torch.float32)

    return X, y


# Initialize training run
run = wandb.init(entity=WANDB_ENTITY,
                project=WANDB_PROJECT,
                group = YOUR_NAME,
                job_type="heart-disease-training",
                name = f"heart_disease_training_validation_{YOUR_NAME}"
                )


# Explicitly accessing Training data
artifact = run.use_artifact('WandB-Intro-Workshop-Registry-dataset/heart-disease:latest', type='dataset')
artifact_dir = artifact.download()

data = pd.read_csv(f"{artifact_dir}/heart_disease_train_dataset.csv")

X_train = torch.tensor(data.drop("target", axis=1).values, dtype=torch.float32)
y_train = torch.tensor(data["target"].values, dtype=torch.float32)

# Load validation data
X_val, y_val = load_and_split_data(WANDB_ENTITY, WANDB_PROJECT, 'heart_disease_validation_dataset', YOUR_NAME, 'val')

# Define a simple neural network model
class HeartDiseaseModel(nn.Module):
    def __init__(self, input_size):
        super(HeartDiseaseModel, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.bn1 = nn.BatchNorm1d(128)
        self.fc2 = nn.Linear(128, 64)
        self.bn2 = nn.BatchNorm1d(64)
        self.fc3 = nn.Linear(64, 32)
        self.bn3 = nn.BatchNorm1d(32)
        self.fc4 = nn.Linear(32, 16)
        self.fc5 = nn.Linear(16, 1)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = F.relu(self.bn1(self.fc1(x)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.fc2(x)))
        x = self.dropout(x)
        x = F.relu(self.bn3(self.fc3(x)))
        x = self.dropout(x)
        x = F.relu(self.fc4(x))
        x = torch.sigmoid(self.fc5(x))
        return x

model = HeartDiseaseModel(input_size=X_train.shape[1])
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

best_performance = float('inf')
version = 1

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    outputs = model(X_train).squeeze()
    loss = criterion(outputs, y_train)
    loss.backward()
    optimizer.step()

    # Calculate and log training accuracy
    predictions = (outputs >= 0.5).float()
    train_accuracy = (predictions == y_train).float().mean().item()

    run.log({"train/epoch": epoch, "train/train_loss": loss.item(), "train/train_accuracy": train_accuracy})

    # Evaluate the model on validation set
    model.eval()
    with torch.no_grad():
        val_outputs = model(X_val).squeeze()
        val_loss = criterion(val_outputs, y_val).item()

        # Calculate and log validation accuracy
        val_predictions = (val_outputs >= 0.5).float()
        val_accuracy = (val_predictions == y_val).float().mean().item()

        run.log({"val/val_loss": val_loss, "val/val_accuracy": val_accuracy})

        if val_loss < best_performance:
            best_performance = val_loss
            model_path = f"heart_disease_model_v{version}.pth"
            torch.save(model.state_dict(), model_path)
            artifact = wandb.Artifact(name=f'heart_disease_model_{YOUR_NAME}', type='model')
            artifact.add_file(model_path)
            wandb.log_artifact(artifact, aliases=[f"v{version}", "best"])
            version += 1

# linking the best model to collection in model registry
target_path = f"WandB-Intro-Workshop-Registry-model/{COLLECTION_NAME}"

# Giving it the production alias since this is our best model
# this line will only work if you are logging to a team entity
run.link_artifact(
  artifact=artifact,
  target_path= target_path,
  aliases=["production"]
)

run.finish()