In [12]:
from torchvision import datasets
from torchvision.transforms import ToTensor
import torch
from torch import nn
root='../dataset'


train_data = datasets.MNIST( root=root,
                                    train="True",
                                    transform=ToTensor(),
                                    download=True)
test_data = datasets.MNIST(root=root,
                                   train=False,
                                   transform=ToTensor())   




In [13]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
train_dataloader = torch.utils.data.DataLoader(
                train_data, 
                batch_size =32,
                shuffle=True,
                num_workers=1)
test_dataloader = torch.utils.data.DataLoader(
                test_data,
                batch_size=32,
                shuffle=False,
                num_workers=1)

In [15]:
from torch.utils.data import DataLoader

# Setup the batch size hyperparameter
BATCH_SIZE = 32


# Let's check out what we've created
print(f"Dataloaders: {train_dataloader, test_dataloader}") 
print(f"Length of train dataloader: {len(train_dataloader)} batches of {BATCH_SIZE}")
print(f"Length of test dataloader: {len(test_dataloader)} batches of {BATCH_SIZE}")

# Check out what's inside the training dataloader
train_features_batch, train_labels_batch = next(iter(train_dataloader))
train_features_batch.shape, train_labels_batch.shape

Dataloaders: (<torch.utils.data.dataloader.DataLoader object at 0x0000015FDCC0F790>, <torch.utils.data.dataloader.DataLoader object at 0x0000015FDCC0FA00>)
Length of train dataloader: 1875 batches of 32
Length of test dataloader: 313 batches of 32


(torch.Size([32, 1, 28, 28]), torch.Size([32]))

In [16]:
import sys
sys.path.insert(0, "..")
from source import models

In [17]:
model = models.CBM(c_chanel=1, 
                   hidden_units=32, 
                   concepts=112, 
                   output_shape=10,
                   img_shape=28).to(device)

In [18]:
import requests
from pathlib import Path 

# Download helper functions from Learn PyTorch repo (if not already downloaded)
if Path("helper_functions.py").is_file():
  print("helper_functions.py already exists, skipping download")
else:
  print("Downloading helper_functions.py")
  # Note: you need the "raw" GitHub URL for this to work
  request = requests.get("https://raw.githubusercontent.com/mrdbourke/pytorch-deep-learning/main/helper_functions.py")
  with open("helper_functions.py", "wb") as f:
    f.write(request.content)

helper_functions.py already exists, skipping download


In [19]:
##loss function and optimizer

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

In [20]:
activation = {}
def get_activation(name):
    def hook(model, input, output):
        activation[name] = output.detach()
    return hook

In [21]:
#training loop

from tqdm.auto import tqdm
from timeit import default_timer as timer
from helper_functions import accuracy_fn

torch.manual_seed=42
torch.cuda.manual_seed=42
#train_time_start_on_cpu = timer()
epochs =1 
flag =0 
for epoch in tqdm(range(epochs)):
    print(f"Epoch: {epoch}\n-----")
    train_loss=0
    for batch, (X,y) in enumerate(train_dataloader):
        model.train()
        X, y = X.to(device), y.to(device)
        if not flag:
            print(f"X shape: {X.shape} y shape: {y.shape}")
            flag=1
        y_logits = model(X)
        #print(f"Logit shape: {y_logits.shape}, Y shape: {y.shape}")
        loss = loss_fn(y_logits, y)
        train_loss+=loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch%600==0:
            print(f"Looked at {batch*len(X)}/{len(train_dataloader.dataset)} samples\n")
    train_loss /= len(train_dataloader)

    #testing loop 

    test_loss, test_acc =0.0, 0.0
    model.eval()
    with torch.inference_mode():
        model.fc1.register_forward_hook(get_activation('fc1'))
        for X,y in test_dataloader:
            X, y = X.to(device), y.to(device)
            test_pred = model(X)
            test_loss += loss_fn(test_pred, y)
            test_acc = accuracy_fn(y_true=y, y_pred=test_pred.argmax(dim=1))
            concepts = activation['fc1']
            concepts_idx = concepts.argmax(dim=1)
            print(f"Length: {len(concepts)} and : {concepts_idx}\n")

        test_loss /= len(test_dataloader)
        test_acc /= len(test_dataloader)
print(f"\nTrain loss: {train_loss:.5f} | Test loss: {test_loss:.5f}, Test acc: {test_acc:.2f}%\n")    
        

  0%|          | 0/1 [00:00<?, ?it/s]

Epoch: 0
-----
X shape: torch.Size([32, 1, 28, 28]) y shape: torch.Size([32])
Looked at 0/60000 samples

Looked at 19200/60000 samples

Looked at 38400/60000 samples

Looked at 57600/60000 samples

Length: 32 and : tensor([ 7, 21, 20, 21, 11, 20, 10,  0, 17, 13, 21, 10, 29, 21, 22,  3, 13,  7,
         3, 10, 29, 10, 10,  3, 29, 21,  7, 29, 21,  0, 22, 20],
       device='cuda:0')

Length: 32 and : tensor([22, 11, 31, 21, 21, 20, 21, 20, 22,  7, 10,  4, 22,  3,  0,  4, 10, 11,
        10, 22,  3,  3, 10, 21, 10, 20, 29, 17,  7,  4, 17,  4],
       device='cuda:0')

Length: 32 and : tensor([31, 20, 10, 11, 13, 21,  7, 21,  4, 13, 20, 31, 22, 21, 13, 31,  7, 10,
         4,  7, 10, 10,  7,  6, 10, 20, 22, 10, 13,  9,  0, 10],
       device='cuda:0')

Length: 32 and : tensor([ 0, 31, 10, 13, 10, 21,  3, 26, 13, 13,  4, 20, 13, 10,  0, 31, 22, 13,
        31, 10, 10, 10, 29, 21,  3, 10,  7, 10,  7, 11,  9,  9],
       device='cuda:0')

Length: 32 and : tensor([ 0,  3, 11, 11,  3,  4,  0, 2

100%|██████████| 1/1 [00:25<00:00, 25.50s/it]


Train loss: 0.24647 | Test loss: 0.05363, Test acc: 0.32%






In [22]:
model.eval()
with torch.inference_mode():
    for X,y in test_dataloader:
        X, y = X.to(device), y.to(device)
        test_pred = model.forward_until_concepts(X)
        print(test_pred.shape)

torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size([32, 112])
torch.Size