### importing from other files and defining constants

In [None]:
from mns_dataloaders import two_concat_biased,two_concat,three_channels,five_channels,five_concat,four_gathered
from mns_models import five_concat_Net,two_concat_Net,three_channels_Net,five_channels_Net,four_gathered_Net
import torch

# Select the compute backend once for the whole notebook:
# CUDA when present, otherwise Apple's MPS backend, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

### defining train and test functions for different tasks

In [None]:
def train(dataloader, model, loss_fn, optimizer, *, verbose=False):
    """Run one training epoch of ``model`` over ``dataloader``.

    Every batch is moved to the module-level ``device``, passed through the
    model, scored with ``loss_fn``, and used for a single optimizer step.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches. Its ``dataset``
            attribute must support ``len()`` when ``verbose`` is true.
        model: Module to train; switched to training mode with
            ``model.train()``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a tensor that
            supports ``backward()``.
        optimizer: Optimizer whose ``step()`` and ``zero_grad()`` are called
            once per batch.
        verbose: When true, print the loss and the number of samples seen
            every 100 batches. Off by default, so the function is silent
            unless asked.

    Returns:
        None.
    """
    # The dataset size is only needed for the progress line.
    size = len(dataloader.dataset) if verbose else None
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if verbose and batch % 100 == 0:
            # .item() copies the scalar back to the host, so it is only
            # called when a progress line is actually printed.
            current = (batch + 1) * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader, model, loss_fn):
    """Evaluate ``model`` on ``dataloader``, print a summary, return accuracy.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; ``len(dataloader)``
            and ``len(dataloader.dataset)`` are used as the batch count and
            the sample count.
        model: Module to evaluate; switched to eval mode with ``model.eval()``.
        loss_fn: Callable ``loss_fn(pred, y)`` whose result supports
            ``.item()``.

    Returns:
        The number of samples whose ``argmax(1)`` prediction equals the
        label, divided by the dataset size.
    """
    sample_count = len(dataloader.dataset)
    batch_count = len(dataloader)
    model.eval()

    loss_sum = 0
    hits = 0
    # No gradients are needed while scoring.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs)
            loss_sum += loss_fn(logits, targets).item()
            matches = logits.argmax(1) == targets
            hits += matches.type(torch.float).sum().item()

    # Mean of the per-batch losses, and fraction of correct samples.
    test_loss = loss_sum / batch_count
    correct = hits / sample_count
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

In [None]:
def train_two_concat(dataloader, model, loss_fn, optimizer, *, verbose=False):
    """Run one training epoch using only the last digit's loss.

    Every batch is moved to the module-level ``device`` and passed through
    the model. Only the last 10 output columns are scored, against the last
    label column; this is the same slice ``test_two_concat`` evaluates, so
    training and evaluation target the same quantity. For a 40-logit,
    4-label layout this is ``pred[:, 30:]`` against ``y[:, 3]``.

    NOTE(review): assumes the model emits 10 logits per digit and that ``y``
    holds one label column per digit -- confirm against the model and the
    dataloader definitions.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches. Its ``dataset``
            attribute must support ``len()`` when ``verbose`` is true.
        model: Module to train; switched to training mode with
            ``model.train()``.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a tensor that
            supports ``backward()``.
        optimizer: Optimizer whose ``step()`` and ``zero_grad()`` are called
            once per batch.
        verbose: When true, print the loss and the number of samples seen
            every 100 batches. Off by default.

    Returns:
        None.
    """
    # The dataset size is only needed for the progress line.
    size = len(dataloader.dataset) if verbose else None
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error on the last digit only.
        pred = model(X)
        loss = loss_fn(pred[:, -10:], y[:, -1])

        # Backpropagation
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if verbose and batch % 100 == 0:
            # .item() copies the scalar back to the host, so it is only
            # called when a progress line is actually printed.
            current = (batch + 1) * len(X)
            print(f"loss: {loss.item():>7f}  [{current:>5d}/{size:>5d}]")

def test_two_concat(dataloader, model, loss_fn):
    """Evaluate last-digit accuracy of ``model``; print a summary, return it.

    Only the last 10 output columns are scored, against the last label
    column of ``y``.

    Args:
        dataloader: Iterable yielding ``(X, y)`` batches; ``len(dataloader)``
            and ``len(dataloader.dataset)`` are used as the batch count and
            the sample count.
        model: Module to evaluate; switched to eval mode with ``model.eval()``.
        loss_fn: Callable ``loss_fn(pred, y)`` whose result supports
            ``.item()``.

    Returns:
        The number of samples whose last-digit ``argmax(1)`` prediction
        equals the last label, divided by the dataset size.
    """
    sample_count = len(dataloader.dataset)
    batch_count = len(dataloader)
    model.eval()

    loss_sum = 0
    hits = 0
    # No gradients are needed while scoring.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            last_logits = outputs[:, -10:]
            last_labels = targets[:, -1]
            loss_sum += loss_fn(last_logits, last_labels).item()
            matches = last_logits.argmax(1) == last_labels
            hits += matches.type(torch.float).sum().item()

    # Mean of the per-batch losses, and fraction of correct samples.
    test_loss = loss_sum / batch_count
    correct = hits / sample_count
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
    return correct

### train the networks and compute accuracy

In [None]:
# One row per task: (dataloader factory, model class, train fn, eval fn).
# Task 0 uses the plain train/test pair; the two concat tasks use the
# *_two_concat pair.
task_setups = (
    (three_channels, three_channels_Net, train, test),
    (two_concat, two_concat_Net, train_two_concat, test_two_concat),
    (two_concat_biased, two_concat_Net, train_two_concat, test_two_concat),
)

# accuracy_array[task][epoch] holds the accuracy returned after each epoch.
accuracy_array = [[] for _ in task_setups]
epochs = 100

for task, (make_loaders, make_model, train_fn, test_fn) in enumerate(task_setups):
    # Build the data first, then a fresh model, loss and optimizer per task.
    train_dataloader, test_dataloader = make_loaders()
    model = make_model().to(device)
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_fn(train_dataloader, model, loss_fn, optimizer)
        accuracy = test_fn(test_dataloader, model, loss_fn)
        accuracy_array[task].append(accuracy)
    print("Done!")

### implement graph for three digit task

In [None]:
import matplotlib.pyplot as plt
# Plot the per-epoch accuracy of the three tasks on one set of axes.
y1, y2, y3 = accuracy_array[0], accuracy_array[1], accuracy_array[2]
# The x axis is the 0-based epoch index, sized from the first task's curve.
x = list(range(len(y1)))

curves = (
    (y1, "three channels"),
    (y2, "two concat"),
    (y3, "two concat biased"),
)
for values, name in curves:
    plt.plot(x, values, label=name)

plt.xlabel("epoch number")
plt.ylabel("Accuracy (on 0-1 scale)")
plt.title('accuracy of across epochs (each train epoch has 100000 images)')
plt.legend()
plt.show()