In [1]:
import os
from types import SimpleNamespace
from pathlib import Path
import pandas as pd

from PIL import Image
import timm

import wandb
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader
import torchvision.transforms as T

from torcheval.metrics import Mean, BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score

import params

In [2]:
PROJECT_NAME = "pis"

In [3]:
# defaults
cfg = SimpleNamespace(
    img_size = 256,
    target_column = 'mold',
    bs = 16,
    seed = 42,
    epochs = 2,
    lr = 2e-3,
    wd=1e-5,
    arch = 'resnet18',
    log_model = False,
    PROJECT_NAME = params.PROJECT_NAME,
    ENTITY = params.ENTITY,
    PROCESSED_DATA_AT = f'{params.DATA_AT}:latest',
)

In [4]:
def prepare_data(PROCESSED_DATA_AT):
    "Get/Download the datasets"
    processed_data_at = wandb.use_artifact(PROCESSED_DATA_AT)
    processed_dataset_dir = Path(processed_data_at.download())
    df = pd.read_csv(processed_dataset_dir / 'data_split.csv')
    df = df[df.stage != 'test'].reset_index(drop=True)
    df['valid'] = df.stage == 'valid'
    return df, processed_dataset_dir

In [5]:
with wandb.init(project=PROJECT_NAME):
    df, processed_dataset_dir = prepare_data(cfg.PROCESSED_DATA_AT)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mcapecape[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Downloading large artifact lemon_data:latest, 137.77MB. 2692 files... 
[34m[1mwandb[0m:   2692 of 2692 files downloaded.  
Done. 0:0:0.4


In [6]:
df.head()

Unnamed: 0,image_id,category_id,mold,file_name,fruit_id,fold,stage,valid
0,0,"[9, 5]",False,images/0001_A_H_0_A.jpg,1,3,train,False
1,100,"[2, 5, 7]",False,images/0003_A_V_150_A.jpg,3,7,train,False
2,101,"[9, 2, 5]",False,images/0003_A_V_15_A.jpg,3,7,train,False
3,102,"[2, 5, 7]",False,images/0003_A_V_165_A.jpg,3,7,train,False
4,103,"[9, 5]",False,images/0003_A_V_30_A.jpg,3,7,train,False


In [7]:
class ImageDataset:
    def __init__(self, dataframe, root_dir, transform=None, image_column='file_name', target_column='mold'):
        """
        Args:
            dataframe (pandas.DataFrame): DataFrame containing image filenames and labels.
            root_dir (string): Directory containing the images.
            transform (callable, optional): Optional transform to be applied on an image sample.
            image_column (string, optional): Name of the column containing the image filenames.
            target_column (string, optional): Name of the column containing the labels.
        """
        self.dataframe = dataframe
        self.root_dir = root_dir
        self.transform = transform
        self.image_column = image_column
        self.target_column = target_column
    
    def __len__(self):
        return len(self.dataframe)
    
    def loc(self, idx):
        idx_of_image_column = self.dataframe.columns.get_loc(self.image_column)
        idx_of_target_column = self.dataframe.columns.get_loc(self.target_column)
        x = self.dataframe.iloc[idx, idx_of_image_column]
        y = self.dataframe.iloc[idx, idx_of_target_column]
        return x, y

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name, label = self.loc(idx)
        img_path = os.path.join(self.root_dir, img_name)
        image = Image.open(img_path)

        if self.transform:
            image = self.transform(image)

        return image, 1.0 if label else 0.

In [8]:
tfms = T.Compose([T.Resize(cfg.img_size), T.ToTensor()])

train_ds = ImageDataset(df[~df.valid], processed_dataset_dir, transform=tfms)
valid_ds = ImageDataset(df[df.valid], processed_dataset_dir, transform=tfms)
len(train_ds), len(valid_ds)

(2275, 210)

In [9]:
x,y = train_ds[0]
type(x), type(y)

(torch.Tensor, float)

In [10]:
train_dataloader = DataLoader(train_ds, batch_size=cfg.bs, shuffle=True, num_workers=4)
valid_dataloader = DataLoader(valid_ds, batch_size=cfg.bs, shuffle=False, num_workers=4)

In [11]:
x,y = next(iter(train_dataloader))
type(x), type(y), x.shape, y.shape

(torch.Tensor, torch.Tensor, torch.Size([16, 3, 256, 256]), torch.Size([16]))

In [12]:
model = timm.create_model(cfg.arch, pretrained=False, num_classes=1)

In [13]:
out = model(x)
loss_func = nn.BCEWithLogitsLoss()
loss = loss_func(out.squeeze(), y.squeeze().float())


In [28]:
from fastprogress import progress_bar
from utils import PredsLogger, set_seed, to_device, model_size, get_class_name_in_snake_case as snake_case

class ClassificationTrainer:
    def __init__(self, train_dataloader, valid_dataloader,  model, metrics, device="cuda"):
        
        self.device = torch.device(device)        
        self.model = model.to(self.device)
        self.train_dataloader = train_dataloader
        self.valid_dataloader = valid_dataloader
        self.train_metrics = [m(device=self.device) for m in metrics]
        self.valid_metrics = [m(device=self.device) for m in metrics]
        self.loss = Mean()
    
    def loss_func(self, x, y):
        "A flattened version of nn.BCEWithLogitsLoss"
        loss_func = nn.BCEWithLogitsLoss()
        return loss_func(x.squeeze(), y.squeeze().float())
    
    def compile(self, epochs=5, lr=2e-3, wd=0.01):
        "Keras style compile method"
        self.epochs = epochs
        self.optim = AdamW(self.model.parameters(), lr=lr, weight_decay=wd)
        self.schedule = OneCycleLR(self.optim, 
                                   max_lr=lr, 
                                   pct_start=0.1,
                                   total_steps=epochs*len(self.train_dataloader))

    def reset_metrics(self):
        self.loss.reset()
        for m in self.train_metrics: m.reset()
        for m in self.valid_metrics: m.reset()
        
    def train_step(self, loss):
        self.optim.zero_grad()
        loss.backward()
        self.optim.step()
        self.schedule.step()
        return loss
        
    def one_epoch(self, train=True):
        if train: 
            self.model.train()
            dl = self.train_dataloader
        else: 
            self.model.eval()
            dl = self.valid_dataloader
        pbar = progress_bar(dl, leave=False)
        preds = []
        for b in pbar:
            with (torch.inference_mode() if not train else torch.enable_grad()):
                images, labels = to_device(b, self.device)
                preds_b = self.model(images).squeeze()
                loss = self.loss_func(preds_b, labels)
                self.loss.update(loss.detach().cpu(), weight=len(images))
                preds.append(preds_b)
                if train:
                    self.train_step(loss)
                    for m in self.train_metrics:
                        m.update(preds_b, labels.long())
                    wandb.log({"train_loss": loss.item(),
                               "learning_rate": self.schedule.get_last_lr()[0]})
                else:
                    for m in self.valid_metrics:
                        m.update(preds_b, labels.long())
            pbar.comment = f"train_loss={loss.item():2.3f}"      
            
        return torch.cat(preds, dim=0), self.loss.compute()
    
    def log_preds(self):
        if wandb.run is not None:
            preds_logger = PredsLogger(ds=self.valid_ds) 
            print("Logging model predictions on validation data")
            preds, _ = self.get_model_preds()
            preds_logger.log(preds=preds)
            
    def print_metrics(self, epoch, train_loss, val_loss):
        print(f"Epoch {epoch+1}/{self.epochs} - train_loss: {train_loss.item():2.3f} - val_loss: {val_loss.item():2.3f}")
    
    def fit(self, log_preds=False):      
        wandb.log({"model_size":model_size(self.model)})   
        for epoch in progress_bar(range(self.epochs), total=self.epochs, leave=True):
            _, train_loss = self.one_epoch(train=True)
            wandb.log({f"train_{snake_case(m)}": m.compute() for m in self.train_metrics})

                            
            ## validation
            _, val_loss = self.one_epoch(train=False)
            wandb.log({f"valid_{snake_case(m)}": m.compute() for m in self.valid_metrics}, commit=False)
            wandb.log({"valid_loss": val_loss.item()}, commit=False)
            self.print_metrics(epoch, train_loss, val_loss)
            self.reset_metrics()
        if log_preds:
            self.log_preds()

In [29]:
trainer = ClassificationTrainer(train_dataloader, valid_dataloader, model, 
                                metrics=[BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score], device="cuda")

In [30]:
trainer.compile(epochs=cfg.epochs, lr=cfg.lr, wd=cfg.wd)

In [34]:
with wandb.init(project=PROJECT_NAME, config=cfg):
    trainer.fit()

0,1
learning_rate,▂▃▅████████▇▇▇▇▆▆▆▆▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▁▁▁▁▁▁
model_size,▁
train_binary_accuracy,▁█
train_binary_f1_score,▁█
train_binary_precision,▁█
train_binary_recall,▁█
train_loss,▂▂▁▂▁▁██▂▁▁▃▅▁▂▁▃▁▁▃▁▂▁▂▁▃▄▁▁▂▁▂▃▁▁▁▁▁▁▇
valid_binary_accuracy,▁█
valid_binary_f1_score,▁█
valid_binary_precision,█▁

0,1
learning_rate,0.0
model_size,11177025.0
train_binary_accuracy,0.97846
train_binary_f1_score,0.85879
train_binary_precision,0.98675
train_binary_recall,0.7602
train_loss,0.01266
valid_binary_accuracy,0.9
valid_binary_f1_score,0.60377
valid_binary_precision,0.61538


In [None]:
def log_preds(images, model_preds, targets):
    # Initialize Weights and Biases
    wandb.init(project="Your_Project_Name")

    # Create Table
    wandb_table = wandb.Table(columns=["Input Images", "Model Predictions", "Ground Truth"])

    for image, pred, target in zip(images, model_preds, targets):
        wandb_table.add_data(wandb.Image(image), pred, target)

    # Log Table
    wandb.log({"Results": wandb_table})

    # Finish logging and clean up
    wandb.finish()

In [None]:
I can provide you with a Python function to log your model outputs using Weights and Biases. You'll need to install Weights and Biases by running:

```bash
pip install wandb
```

Here is the Python function:

```python
import wandb
from wandb import DataLogger

def log_preds(images, model_preds, targets):
    # Initialize Weights and Biases
    wandb.init(project="Your_Project_Name")

    # Create Table
    wandb_table = wandb.Table(columns=["Input Images", "Model Predictions", "Ground Truth"])

    for image, pred, target in zip(images, model_preds, targets):
        wandb_table.add_data(wandb.Image(image), pred, target)

    # Log Table
    wandb.log({"Results": wandb_table})

    # Finish logging and clean up
    wandb.finish()
```

Replace `Your_Project_Name` with the relevant project name.

This `log_preds` function initializes Weights and Biases, creates a wandb.Table, and iteratively adds the images, model_preds, and targets to the table. Once all data is added to the table, it logs the table and
finishes the Weights and Biases run.