In [1]:
import torch
import matplotlib.pyplot as plt

In [36]:
def train_epoch(model,data_loader,loss_fn,optimiser,device):
    """Run one optimisation pass of ``model`` over every batch in ``data_loader``.

    Each batch is moved to ``device``, pushed through the model, scored with
    ``loss_fn`` and used for one ``optimiser`` step. A ``|`` progress tick is
    printed for batch 0 and every 5th batch after it; the mean cost is printed
    at the end (all on the current line).

    Args:
        model: module to train; it is switched to training mode first.
        data_loader: iterable of ``(inputs, targets)`` pairs that also supports
            ``len()`` (used as the number of batches when averaging).
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar tensor.
        optimiser: optimiser over ``model``'s parameters.
        device: target passed to ``Tensor.to``.

    Returns:
        The unweighted mean of the per-batch loss values, or ``None`` as soon as
        a batch produces a NaN loss.
    """
    model.train()
    running_loss = 0
    for batch_idx, (inputs, targets) in enumerate(data_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)
        batch_loss = loss_fn(model(inputs), targets)
        # Abort before backward()/step() so a NaN loss never updates the weights.
        if torch.isnan(batch_loss):
            return None
        optimiser.zero_grad()
        batch_loss.backward()
        optimiser.step()
        running_loss += batch_loss.item()
        if batch_idx % 5 == 0:
            print("|",end="")
    cost = running_loss / len(data_loader)
    print(f"\t\tCost : {cost}",end="")
    return cost

In [37]:
def predict(model,input):
    """Return the arg-max class index along dim 1 of ``model(input)``.

    The model is put into evaluation mode (and left there), and the forward pass
    runs under ``torch.no_grad()`` so no autograd graph is recorded.

    Args:
        model: callable module producing scores; dim 1 is taken as the class axis.
        input: batch passed straight to ``model``.

    Returns:
        Tensor of indices of the maximum score along dim 1.
    """
    model.eval()
    with torch.no_grad():
        return model(input).argmax(dim=1)

In [38]:
def getAccuracy(predictions,Y):
    """Percentage (0-100, as a Python float) of elements where ``predictions == Y``."""
    matches = (predictions == Y).float()
    return matches.mean().item() * 100

In [39]:
def test(model,data_loader,device,testing=True):
    """Compute classification accuracy of ``model`` over all of ``data_loader``.

    Accuracy is pooled over every element seen (total correct / total compared),
    so a short final batch is weighted by its actual size rather than counting
    as much as a full batch.

    Args:
        model: module evaluated via ``predict`` (which switches it to eval mode).
        data_loader: iterable of ``(inputs, targets)`` pairs.
        device: target passed to ``Tensor.to``.
        testing: only affects the printed label ('test' vs 'train').

    Returns:
        Accuracy as a percentage in [0, 100].

    Raises:
        ValueError: if ``data_loader`` yields no elements to compare.
    """
    correct = 0
    total = 0
    for (X,Y) in data_loader:
        X,Y = X.to(device),Y.to(device)
        predictions = predict(model,X)
        matches = predictions == Y
        correct += matches.sum().item()
        # numel of the comparison mirrors the element-wise mean used by getAccuracy.
        total += matches.numel()
    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy")
    accuracy = correct / total * 100
    print(f"\t{'test' if testing else 'train'} accuracy : {accuracy:.2f}%",end="")
    return accuracy

In [40]:
def train(model,train_data_loader,loss_fn,optimiser,device,epochs,test_data_loader=None):
    """Train ``model`` for up to ``epochs`` epochs, tracking cost and accuracy.

    Each epoch runs ``train_epoch``, then measures accuracy on the training
    loader and, if ``test_data_loader`` is given, on the test loader. Training
    stops early if ``train_epoch`` reports a NaN loss (returns ``None``).

    Returns:
        ``(costs, training_accs, testing_accs)``: per-epoch lists. The lists have
        one entry per completed epoch; ``testing_accs`` is empty when no
        ``test_data_loader`` is supplied.
    """
    costs = []
    training_accs = []
    testing_accs = []
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}")
        cost = train_epoch(model,train_data_loader,loss_fn,optimiser,device)
        # Identity check: `== None` would dispatch to a user-defined __eq__.
        if cost is None:
            print("Cost became NaN and model crashed")
            break
        train_accuracy = test(model,train_data_loader,device,testing=False)
        if test_data_loader is not None:
            test_accuracy = test(model,test_data_loader,device,testing=True)
            testing_accs.append(test_accuracy)
        costs.append(cost)
        training_accs.append(train_accuracy)
        print("\n--------------------------")
    print("Done")
    return (costs,training_accs,testing_accs)

In [41]:
def plot_costs(costs):
    """Show a line chart of training cost against 1-based epoch number."""
    _, ax = plt.subplots(figsize=(10, 6))
    epoch_numbers = [n + 1 for n in range(len(costs))]
    ax.plot(epoch_numbers, costs, marker='o', linestyle='-', color='b')
    ax.set_title('Cost vs Epochs')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Cost')
    ax.grid(True)
    plt.show()

In [42]:
def plot_accuracies(train_accs,test_accs=None):
    """Show training (and, if available, testing) accuracy against epoch number.

    Args:
        train_accs: per-epoch training accuracies in percent.
        test_accs: optional per-epoch testing accuracies in percent. ``None`` or
            an empty sequence (what ``train`` returns when no test loader is
            given) plots the training curve only instead of failing on mismatched
            x/y lengths.
    """
    has_test = test_accs is not None and len(test_accs) > 0
    plt.figure(figsize=(10, 6))
    train_epochs = list(range(1,len(train_accs)+1))
    plt.plot(train_epochs, train_accs, marker='o', linestyle='-', color='b', label='Training Accuracy')
    if has_test:
        # Own x-range so the curve still plots if lengths ever differ.
        test_epochs = list(range(1,len(test_accs)+1))
        plt.plot(test_epochs, test_accs, marker='o', linestyle='--', color='g', label='Testing Accuracy')
        plt.title('Training and Testing Accuracy vs Epochs')
    else:
        plt.title('Training Accuracy vs Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.ylim(0, 100)
    plt.legend()
    plt.grid(True)
    plt.show()