# Classification with Segment Network
* We flatten a BW image row/col wise with values as positions and see if we can classify with Segment Net


# Initialization

In [23]:
import math
import random
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import time
from tqdm import tqdm

from segment import Segment

In [2]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [3]:
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# MNIST dataset

In [192]:
from torchvision import datasets, transforms

flatten_by_row_transform = transforms.Lambda(lambda x: x.flatten())
flatten_by_col_transform = transforms.Lambda(lambda x: x.permute(0,2,1).flatten())
bw_array_transform = transforms.Lambda(lambda x: torch.where(x < x.mean(), 0., 1.))
position_transform = transforms.Lambda(lambda x: x*torch.arange(1, x.shape[0]+1.)/x.shape[0] )

train_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    bw_array_transform,
    flatten_by_row_transform,
    #position_transform
])

test_transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    bw_array_transform,
    flatten_by_row_transform
    #position_transform
])

train_set = datasets.MNIST('data', train=True, download=False, transform=train_transform)
test_set = datasets.MNIST('data', train=False, download=False, transform=test_transform)
print(len(train_set), len(test_set))
print(train_set[0][0].shape, test_set[0][0].shape)



60000 10000
torch.Size([784]) torch.Size([784])


# Model Definition
* Start with single layer Segment with output = number of labels
* Then increase number of layers to see if loss reduces 
* Then try transposing input x and batch size and see if it trains better

# Model Training

In [73]:
class SimpleModel(nn.Module):
    def __init__(self, input_size, output_size, segment_size):
        super().__init__()
        self.seg1 = Segment(input_size, output_size, segment_size)
        self.init = False

    def forward(self, x):
        if self.training and not self.init:
            self.seg1.custom_init(torch.zeros_like(x.min(dim=0).values), torch.ones_like(x.max(dim=0).values))
            self.init = True
        x = self.seg1(x)
        return x

In [160]:
#Permute input to seg1.
#Seg2 input should be the parameters from seg1 (not output)
class MyModel(nn.Module):
    def __init__(self, input_size, batch_size, segment1_size, segment2_size, output_dim):
        super().__init__()
        self.init = False
        self.seg1 = Segment(1, batch_size, segment1_size)
        seg2_output_size = (segment1_size+1)*2
        self.x_in = None
        self.seg2 = Segment(input_size, seg2_output_size, int(seg2_output_size/4))
        self.seg3 = Segment(seg2_output_size, output_dim, segment2_size)
        
    def custom_init(self, x):
        x_min = torch.zeros(self.seg1.in_features, dtype=x.dtype)
        x_max = torch.ones(self.seg1.in_features, dtype=x.dtype)
        self.seg1.custom_init(x_min, x_max)
        x_min = torch.zeros(self.seg2.in_features, dtype=x.dtype)
        x_max = torch.ones(self.seg2.in_features, dtype=x.dtype)
        self.seg2.custom_init(x_min, x_max)
        x_min = torch.zeros(self.seg3.in_features, dtype=x.dtype)
        x_max = torch.ones(self.seg3.in_features, dtype=x.dtype)
        self.seg3.custom_init(x_min, x_max)
        
        self.init = True

    def forward(self, x):
        if self.training:
            if not self.init:
                self.custom_init(x)
            
            #Train a model1 that predicts x,y for the input x_in so output is x
            self.x_in = torch.arange(1, x.shape[1]+1.)/x.shape[1]
            self.x_in.unsqueeze_(-1)
            y1 = self.seg1(self.x_in)
            loss1 = F.mse_loss(y1, x.permute(1,0))
            seg1_params = torch.cat((self.seg1.x, self.seg1.y), dim=1)
            # reshape
            seg1_params = seg1_params.view(seg1_params.shape[2], seg1_params.shape[1])
            #train second model that predicts model1 x,y parameters for the input x.
            #these are equivalent to embeddings for the input x
        
        self.embeddings = self.seg2(x)
        
        if self.training:
            loss2 = F.mse_loss(self.embeddings, seg1_params)
        
        ypred = self.seg3(self.embeddings)

        if self.training:
            return loss1, loss2, ypred
        else:
            return ypred
    

In [193]:
# Define Model : 1 input, 1 output, play with segments starting from 1/2 of image pixels.
BATCH_SIZE=64

#model = SimpleModel(784,10,10)
model = MyModel(784, 64, 14, 10, 10)
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

criterion = nn.CrossEntropyLoss()

lr=1e-4
#optimizer only optimizes parameter that are sent to it in arg1
optimizer = torch.optim.AdamW(model.parameters(),
                               lr=lr, betas=(0.9, 0.999), eps=1e-8)

Total parameters: 384840


In [187]:
#model = MyModel(784, 64, 14, 10, 10)
#total_params = sum(p.numel() for p in model.parameters())
#print(f"Total parameters: {total_params}")
#loss1, loss2, ypred = model(img)
#loss = criterion(y_pred, label) + loss1 + loss2

In [194]:
num_epochs = 40

train_dataloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True)

start = time.time()
for epoch in range(1,num_epochs+1):
    model.train()
    t0 = time.time()
    train_running_loss = 0
    step = 0
    for data in tqdm(train_dataloader, position=0, leave=True):
        img, label = data

        #don't train for wrong batch size for now
        if img.shape[0] != BATCH_SIZE:
            continue

        # ===================forward=====================
        loss1, loss2, y_pred = model(img)
        y_pred_label = torch.argmax(y_pred, dim=1)
        loss = criterion(y_pred, label) + loss1 + loss2
        if(math.isnan(loss)):
            print(f"nan loss at step {step}")
            break
        # ===================backward====================
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_running_loss += loss.item()
        step += 1
    # ===================log========================
    t1 = time.time()
    train_loss = train_running_loss / len(train_dataloader)
    
    if(math.isnan(loss)):
        print(f"nan loss ")
        break

    model.eval()
    val_labels = []
    val_preds = []
    val_running_loss = 0
    with torch.no_grad():
        for data in tqdm(test_dataloader, position=0, leave=True):
            img, label = data
            y_pred = model(img)
            y_pred_label = torch.argmax(y_pred, dim=1)
            
            val_labels.extend(label.cpu().detach())
            val_preds.extend(y_pred_label.cpu().detach())
            
            loss = criterion(y_pred, label)
            val_running_loss += loss.item()
    val_loss = val_running_loss/len(test_dataloader)

    print("-"*30)
    print(f"Train Loss EPOCH {epoch}: {train_loss:.4f}")
    print(f"Valid Loss EPOCH {epoch}: {val_loss:.4f}")
    #print(f"Train Accuracy EPOCH {epoch+1}: {sum(1 for x,y in zip(train_preds, train_labels) if x == y) / len(train_labels):.4f}")
    print(f"Valid Accuracy EPOCH {epoch}: {sum(1 for x,y in zip(val_preds, val_labels) if x == y) / len(val_labels):.4f}")
    print("-"*30)

stop = time.time()
print(f"Training Time: {stop-start:.2f}s")

100%|██████████| 938/938 [00:57<00:00, 16.33it/s]
100%|██████████| 157/157 [00:04<00:00, 34.87it/s]


------------------------------
Train Loss EPOCH 1: 1.9211
Valid Loss EPOCH 1: 0.9658
Valid Accuracy EPOCH 1: 0.8171
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.55it/s]
100%|██████████| 157/157 [00:04<00:00, 35.20it/s]


------------------------------
Train Loss EPOCH 2: 0.9239
Valid Loss EPOCH 2: 0.4259
Valid Accuracy EPOCH 2: 0.8944
------------------------------


100%|██████████| 938/938 [00:57<00:00, 16.38it/s]
100%|██████████| 157/157 [00:04<00:00, 34.84it/s]


------------------------------
Train Loss EPOCH 3: 0.6812
Valid Loss EPOCH 3: 0.3144
Valid Accuracy EPOCH 3: 0.9179
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.46it/s]
100%|██████████| 157/157 [00:04<00:00, 34.90it/s]


------------------------------
Train Loss EPOCH 4: 0.5905
Valid Loss EPOCH 4: 0.2838
Valid Accuracy EPOCH 4: 0.9209
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.49it/s]
100%|██████████| 157/157 [00:04<00:00, 34.61it/s]


------------------------------
Train Loss EPOCH 5: 0.5440
Valid Loss EPOCH 5: 0.2431
Valid Accuracy EPOCH 5: 0.9347
------------------------------


100%|██████████| 938/938 [00:57<00:00, 16.33it/s]
100%|██████████| 157/157 [00:04<00:00, 34.03it/s]


------------------------------
Train Loss EPOCH 6: 0.5126
Valid Loss EPOCH 6: 0.2156
Valid Accuracy EPOCH 6: 0.9420
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.46it/s]
100%|██████████| 157/157 [00:04<00:00, 33.78it/s]


------------------------------
Train Loss EPOCH 7: 0.4931
Valid Loss EPOCH 7: 0.2191
Valid Accuracy EPOCH 7: 0.9394
------------------------------


100%|██████████| 938/938 [00:57<00:00, 16.43it/s]
100%|██████████| 157/157 [00:04<00:00, 34.72it/s]


------------------------------
Train Loss EPOCH 8: 0.4775
Valid Loss EPOCH 8: 0.2057
Valid Accuracy EPOCH 8: 0.9410
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.49it/s]
100%|██████████| 157/157 [00:04<00:00, 34.54it/s]


------------------------------
Train Loss EPOCH 9: 0.4614
Valid Loss EPOCH 9: 0.1936
Valid Accuracy EPOCH 9: 0.9433
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.66it/s]
100%|██████████| 157/157 [00:04<00:00, 35.25it/s]


------------------------------
Train Loss EPOCH 10: 0.4497
Valid Loss EPOCH 10: 0.2006
Valid Accuracy EPOCH 10: 0.9398
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.51it/s]
100%|██████████| 157/157 [00:04<00:00, 35.33it/s]


------------------------------
Train Loss EPOCH 11: 0.4394
Valid Loss EPOCH 11: 0.1912
Valid Accuracy EPOCH 11: 0.9420
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.69it/s]
100%|██████████| 157/157 [00:04<00:00, 34.45it/s]


------------------------------
Train Loss EPOCH 12: 0.4296
Valid Loss EPOCH 12: 0.1701
Valid Accuracy EPOCH 12: 0.9500
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.54it/s]
100%|██████████| 157/157 [00:04<00:00, 37.12it/s]


------------------------------
Train Loss EPOCH 13: 0.4272
Valid Loss EPOCH 13: 0.1667
Valid Accuracy EPOCH 13: 0.9503
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.09it/s]
100%|██████████| 157/157 [00:04<00:00, 37.72it/s]


------------------------------
Train Loss EPOCH 14: 0.4194
Valid Loss EPOCH 14: 0.1804
Valid Accuracy EPOCH 14: 0.9457
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.11it/s]
100%|██████████| 157/157 [00:04<00:00, 36.84it/s]


------------------------------
Train Loss EPOCH 15: 0.4145
Valid Loss EPOCH 15: 0.1616
Valid Accuracy EPOCH 15: 0.9530
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.10it/s]
100%|██████████| 157/157 [00:04<00:00, 36.63it/s]


------------------------------
Train Loss EPOCH 16: 0.4079
Valid Loss EPOCH 16: 0.1643
Valid Accuracy EPOCH 16: 0.9498
------------------------------


100%|██████████| 938/938 [00:55<00:00, 16.98it/s]
100%|██████████| 157/157 [00:04<00:00, 36.86it/s]


------------------------------
Train Loss EPOCH 17: 0.4027
Valid Loss EPOCH 17: 0.1752
Valid Accuracy EPOCH 17: 0.9471
------------------------------


100%|██████████| 938/938 [00:55<00:00, 17.00it/s]
100%|██████████| 157/157 [00:04<00:00, 36.59it/s]


------------------------------
Train Loss EPOCH 18: 0.3984
Valid Loss EPOCH 18: 0.1493
Valid Accuracy EPOCH 18: 0.9539
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.21it/s]
100%|██████████| 157/157 [00:04<00:00, 37.55it/s]


------------------------------
Train Loss EPOCH 19: 0.3928
Valid Loss EPOCH 19: 0.1548
Valid Accuracy EPOCH 19: 0.9521
------------------------------


100%|██████████| 938/938 [00:53<00:00, 17.64it/s]
100%|██████████| 157/157 [00:04<00:00, 36.36it/s]


------------------------------
Train Loss EPOCH 20: 0.3931
Valid Loss EPOCH 20: 0.1727
Valid Accuracy EPOCH 20: 0.9470
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.33it/s]
100%|██████████| 157/157 [00:04<00:00, 37.68it/s]


------------------------------
Train Loss EPOCH 21: 0.3851
Valid Loss EPOCH 21: 0.1668
Valid Accuracy EPOCH 21: 0.9489
------------------------------


100%|██████████| 938/938 [00:53<00:00, 17.53it/s]
100%|██████████| 157/157 [00:04<00:00, 37.85it/s]


------------------------------
Train Loss EPOCH 22: 0.3869
Valid Loss EPOCH 22: 0.1566
Valid Accuracy EPOCH 22: 0.9518
------------------------------


100%|██████████| 938/938 [00:53<00:00, 17.58it/s]
100%|██████████| 157/157 [00:04<00:00, 37.96it/s]


------------------------------
Train Loss EPOCH 23: 0.3831
Valid Loss EPOCH 23: 0.1769
Valid Accuracy EPOCH 23: 0.9465
------------------------------


100%|██████████| 938/938 [00:53<00:00, 17.56it/s]
100%|██████████| 157/157 [00:04<00:00, 36.96it/s]


------------------------------
Train Loss EPOCH 24: 0.3772
Valid Loss EPOCH 24: 0.1579
Valid Accuracy EPOCH 24: 0.9530
------------------------------


100%|██████████| 938/938 [32:54<00:00,  2.10s/it]   
100%|██████████| 157/157 [00:04<00:00, 37.27it/s]


------------------------------
Train Loss EPOCH 25: 0.3776
Valid Loss EPOCH 25: 0.1762
Valid Accuracy EPOCH 25: 0.9426
------------------------------


100%|██████████| 938/938 [14:10<00:00,  1.10it/s]   
100%|██████████| 157/157 [00:04<00:00, 33.62it/s]


------------------------------
Train Loss EPOCH 26: 0.3739
Valid Loss EPOCH 26: 0.1614
Valid Accuracy EPOCH 26: 0.9489
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.68it/s]
100%|██████████| 157/157 [00:04<00:00, 35.70it/s]


------------------------------
Train Loss EPOCH 27: 0.3725
Valid Loss EPOCH 27: 0.1433
Valid Accuracy EPOCH 27: 0.9553
------------------------------


100%|██████████| 938/938 [00:55<00:00, 16.99it/s]
100%|██████████| 157/157 [00:04<00:00, 36.75it/s]


------------------------------
Train Loss EPOCH 28: 0.3687
Valid Loss EPOCH 28: 0.1588
Valid Accuracy EPOCH 28: 0.9493
------------------------------


100%|██████████| 938/938 [00:55<00:00, 16.79it/s]
100%|██████████| 157/157 [00:04<00:00, 35.93it/s]


------------------------------
Train Loss EPOCH 29: 0.3632
Valid Loss EPOCH 29: 0.1601
Valid Accuracy EPOCH 29: 0.9487
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.72it/s]
100%|██████████| 157/157 [00:04<00:00, 35.92it/s]


------------------------------
Train Loss EPOCH 30: 0.3667
Valid Loss EPOCH 30: 0.1461
Valid Accuracy EPOCH 30: 0.9539
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.72it/s]
100%|██████████| 157/157 [00:04<00:00, 36.09it/s]


------------------------------
Train Loss EPOCH 31: 0.3641
Valid Loss EPOCH 31: 0.1479
Valid Accuracy EPOCH 31: 0.9554
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.62it/s]
100%|██████████| 157/157 [00:04<00:00, 36.37it/s]


------------------------------
Train Loss EPOCH 32: 0.3617
Valid Loss EPOCH 32: 0.1392
Valid Accuracy EPOCH 32: 0.9562
------------------------------


100%|██████████| 938/938 [00:55<00:00, 16.95it/s]
100%|██████████| 157/157 [00:04<00:00, 35.37it/s]


------------------------------
Train Loss EPOCH 33: 0.3591
Valid Loss EPOCH 33: 0.1906
Valid Accuracy EPOCH 33: 0.9379
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.09it/s]
100%|██████████| 157/157 [00:04<00:00, 36.31it/s]


------------------------------
Train Loss EPOCH 34: 0.3590
Valid Loss EPOCH 34: 0.1412
Valid Accuracy EPOCH 34: 0.9562
------------------------------


100%|██████████| 938/938 [00:54<00:00, 17.07it/s]
100%|██████████| 157/157 [00:04<00:00, 34.70it/s]


------------------------------
Train Loss EPOCH 35: 0.3580
Valid Loss EPOCH 35: 0.1665
Valid Accuracy EPOCH 35: 0.9494
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.52it/s]
100%|██████████| 157/157 [00:04<00:00, 34.68it/s]


------------------------------
Train Loss EPOCH 36: 0.3550
Valid Loss EPOCH 36: 0.1502
Valid Accuracy EPOCH 36: 0.9514
------------------------------


100%|██████████| 938/938 [00:55<00:00, 16.99it/s]
100%|██████████| 157/157 [00:04<00:00, 36.08it/s]


------------------------------
Train Loss EPOCH 37: 0.3492
Valid Loss EPOCH 37: 0.1425
Valid Accuracy EPOCH 37: 0.9572
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.67it/s]
100%|██████████| 157/157 [00:04<00:00, 34.55it/s]


------------------------------
Train Loss EPOCH 38: 0.3536
Valid Loss EPOCH 38: 0.1573
Valid Accuracy EPOCH 38: 0.9528
------------------------------


100%|██████████| 938/938 [00:56<00:00, 16.54it/s]
100%|██████████| 157/157 [00:04<00:00, 36.18it/s]


------------------------------
Train Loss EPOCH 39: 0.3522
Valid Loss EPOCH 39: 0.1377
Valid Accuracy EPOCH 39: 0.9562
------------------------------


100%|██████████| 938/938 [00:55<00:00, 16.99it/s]
100%|██████████| 157/157 [00:04<00:00, 35.80it/s]

------------------------------
Train Loss EPOCH 40: 0.3507
Valid Loss EPOCH 40: 0.1474
Valid Accuracy EPOCH 40: 0.9546
------------------------------
Training Time: 5120.11s





In [183]:
# Save graph to a file
#!pip install torchviz
from torchviz import make_dot

# Generate a Graphviz object from the computation graph
graph = make_dot(loss, params=dict(model.named_parameters())) 

# Save the graph as a PDF or any other format if needed
graph.render("model_Classification_Segment_v1_graph")

'model_Classification_Segment_v1_graph.pdf'

# Experiment Observations
* model=SimpleModel(784,10,8), model_params=141120, batch_size=64, lr=1e-5, epochs=40, test_acc=0.9118
* model=SimpleModel(784,10,10), model_params=172480, batch_size=64, lr=1e-4, epochs=40, test_acc=0.9239

### Turned off bw_transform and position transform. feeding flatten grayscape to above model
* model=SimpleModel(784,10,10), model_params=172480, batch_size=64, lr=1e-4, epochs=40, test_acc=0.9310 (max 0.935)
* so we are not really learning from shape

### Trained a new model that uses multiple segment nets and tries to predict x,y
*  MyModel(784, 64, 14, 10, 10), model_params=384840, batch_size=64, lr=1e-4, epochs=20, test_acc=0.9231 (max 0.9355)

### Turned ON bw_transform (no position transform)
*  MyModel(784, 64, 14, 10, 10), model_params=384840, batch_size=64, lr=1e-4, epochs=40, test_acc=0.9546 (max 0.9572)
