# Model Visualization

This notebook loads the trained UNet model and visualizes its predictions on the Cityscapes validation set.

In [None]:
import copy
import os
import random
import re
import sys

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import segmentation_models_pytorch as smp
import torch
from torch.utils.data import DataLoader

# Add src to path so we can import modules
sys.path.append(os.path.abspath('src'))

from datasets import CityscapesDataset, image_transform, mask_transform, remap_mask, convert

## Load Model

We initialize the UNet architecture and load the trained weights.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build the architecture only. load_state_dict below is strict, so every
# parameter and buffer is overwritten by the checkpoint; downloading ImageNet
# encoder weights first would be wasted work (and needs network access).
model = smp.Unet(
    encoder_name="resnet50",
    encoder_weights=None,
    classes=19,
    activation=None,
)

# Load trained weights. Fail loudly: continuing with an untrained decoder would
# make every visualisation below silently meaningless.
checkpoint_path = "unet.pth"
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(f"Checkpoint file '{checkpoint_path}' not found.")

# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The file is either a wrapper dict holding 'model_state_dict' or the raw state dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    state_dict = checkpoint['model_state_dict']
else:
    state_dict = checkpoint
model.load_state_dict(state_dict)
print("Model loaded successfully")

model.to(device)
model.eval();

## Load Dataset

Load the Cityscapes validation dataset.

In [None]:
# Cityscapes validation split; the dataset's `tvt` argument selects the split (1 = validation).
val_dataset = CityscapesDataset(
    image_transform=image_transform,
    mask_transform=mask_transform,
    tvt=1,
)

n_val = len(val_dataset)
print(f"Validation dataset size: {n_val}")

## Visualization

Helper function to plot an input image next to its ground-truth mask and the model's prediction.

In [None]:
CITYSCAPES_CLASSES = {
    0: "road", 1: "sidewalk", 2: "building", 3: "wall", 4: "fence",
    5: "pole", 6: "traffic light", 7: "traffic sign", 8: "vegetation",
    9: "terrain", 10: "sky", 11: "person", 12: "rider", 13: "car",
    14: "truck", 15: "bus", 16: "train", 17: "motorcycle", 18: "bicycle"
}


def visualize_prediction(dataset, index, model, device, ignore_index=255):
    """Plot an input image next to its ground-truth mask and the model's prediction.

    Args:
        dataset: Indexable dataset whose items are ``(image, mask)`` tensors.
            ``image`` is un-normalised below with the ImageNet mean/std, so it is
            presumably a normalised (C, H, W) tensor -- TODO confirm against
            ``image_transform``.
        index: Index of the sample to visualize.
        model: Segmentation model returning class logits; argmax is taken over dim 1.
        device: Device the model lives on; the input batch is moved there.
        ignore_index: Mask label for void / unlabeled pixels. These are drawn in
            black instead of being clipped into the color of the last class.
    """
    image, mask = dataset[index]
    num_classes = len(CITYSCAPES_CLASSES)
    vmax = num_classes - 1

    # Inference on a batch of one
    input_tensor = image.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        prediction = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    # Undo the ImageNet normalisation so the image displays with natural colors
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_display = image.permute(1, 2, 0).numpy()
    img_display = np.clip(std * img_display + mean, 0, 1)

    # With vmin=0 / vmax=vmax, void pixels would be clipped to the top of the
    # colormap and look exactly like the last class ("bicycle"). Mask them and
    # render masked pixels in black. Copy the colormap so the shared one is untouched.
    cmap = copy.copy(plt.get_cmap('jet'))
    cmap.set_bad('black')
    gt_display = np.ma.masked_equal(mask.numpy(), ignore_index)

    # Legend colors use the same value -> color mapping imshow applies below
    patches = [mpatches.Patch(color=cmap(i / vmax), label=f"{i}: {CITYSCAPES_CLASSES[i]}")
               for i in range(num_classes)]
    patches.append(mpatches.Patch(color='black', label=f"{ignore_index}: void"))

    # Plot - 2 rows: Top for images, Bottom for legend
    fig = plt.figure(figsize=(15, 7))
    gs = fig.add_gridspec(2, 3, height_ratios=[5, 1])
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[0, 1])
    ax3 = fig.add_subplot(gs[0, 2])
    ax_legend = fig.add_subplot(gs[1, :])

    ax1.imshow(img_display)
    ax1.set_title("Input Image")
    ax1.axis('off')

    ax2.imshow(gt_display, cmap=cmap, vmin=0, vmax=vmax)
    ax2.set_title("Ground Truth")
    ax2.axis('off')

    ax3.imshow(prediction, cmap=cmap, vmin=0, vmax=vmax)
    ax3.set_title("Prediction")
    ax3.axis('off')

    # Legend
    ax_legend.legend(handles=patches, loc='center', ncol=7, frameon=False)
    ax_legend.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Visualize a few random samples
import random

indices = random.sample(range(len(val_dataset)), 3)

for idx in indices:
    print(f"Visualizing sample {idx}")
    visualize_prediction(val_dataset, idx, model, device)

## Error Visualization

We highlight correct pixels in **green** and incorrect pixels in **red**. Pixels labeled as void (255) are excluded from the comparison and shown in black.

In [None]:
def build_error_map(prediction, mask, ignore_index=255):
    """Return a uint8 RGB image comparing a predicted label map to the ground truth.

    Correct pixels are green (0, 255, 0), incorrect pixels red (255, 0, 0), and
    pixels whose ground-truth label equals ``ignore_index`` stay black (0, 0, 0).
    ``prediction`` and ``mask`` are integer label arrays of the same shape; the
    result has that shape plus a trailing channel axis of size 3.
    """
    valid = mask != ignore_index
    correct = (prediction == mask) & valid
    error_map = np.zeros(mask.shape + (3,), dtype=np.uint8)
    error_map[correct] = [0, 255, 0]
    error_map[valid & ~correct] = [255, 0, 0]
    return error_map


def visualize_error(dataset, index, model, device, ignore_index=255, num_classes=19):
    """Plot input, ground truth, prediction and a per-pixel correctness map.

    Args:
        dataset: Indexable dataset whose items are ``(image, mask)`` tensors;
            ``image`` is presumably ImageNet-normalised (C, H, W) -- TODO confirm.
        index: Index of the sample to visualize.
        model: Segmentation model returning class logits; argmax is taken over dim 1.
        device: Device the model lives on; the input batch is moved there.
        ignore_index: Ground-truth label for void pixels. They are excluded from
            the error map and drawn in black in the ground-truth panel.
        num_classes: Number of model classes; sets the colormap range.
    """
    image, mask = dataset[index]

    # Inference on a batch of one
    input_tensor = image.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
        prediction = torch.argmax(output, dim=1).squeeze(0).cpu().numpy()

    mask = mask.numpy()
    error_map = build_error_map(prediction, mask, ignore_index)

    # Undo the ImageNet normalisation for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_display = image.permute(1, 2, 0).numpy()
    img_display = np.clip(std * img_display + mean, 0, 1)

    # Void pixels would otherwise be clipped to vmax and shown as the last class;
    # mask them and render black (matching the error map). Copy the shared colormap.
    vmax = num_classes - 1
    cmap = copy.copy(plt.get_cmap('jet'))
    cmap.set_bad('black')
    gt_display = np.ma.masked_equal(mask, ignore_index)

    # Plot
    fig, ax = plt.subplots(1, 4, figsize=(20, 5))

    ax[0].imshow(img_display)
    ax[0].set_title("Input Image")
    ax[0].axis('off')

    ax[1].imshow(gt_display, cmap=cmap, vmin=0, vmax=vmax)
    ax[1].set_title("Ground Truth")
    ax[1].axis('off')

    ax[2].imshow(prediction, cmap=cmap, vmin=0, vmax=vmax)
    ax[2].set_title("Prediction")
    ax[2].axis('off')

    ax[3].imshow(error_map)
    ax[3].set_title("Error Map (Green=Correct, Red=Wrong)")
    ax[3].axis('off')

    plt.tight_layout()
    plt.show()

# Error maps for the same samples shown above
for sample_idx in indices:
    print(f"Visualizing error for sample {sample_idx}")
    visualize_error(val_dataset, sample_idx, model, device)

## Class Labels

Print the class labels for reference.

In [None]:
# Reference table mapping model output channels (0-18) to class names.
# The loop variable is `class_id`, not `id`, to avoid shadowing the builtin.
print("Cityscapes Class Labels:")
for class_id, name in CITYSCAPES_CLASSES.items():
    print(f"{class_id}: {name}")

## Training Loss

Plot the training loss parsed from the captured training log below.

In [None]:
# Training log, presumably copied from the training script's stdout (the first
# line appears truncated: "tarting" -> "Starting"). A loss line is logged every
# 10 iterations and an "[Acc]" line follows each "Epoch N done".
raw_logs = """
tarting training on device:  cuda
Class weights computed. Starting training...
Epoch 0, iteration 0, loss = 3.3058016300201416
Epoch 0, iteration 10, loss = 2.7614519596099854
Epoch 0, iteration 20, loss = 2.7438251972198486
Epoch 0, iteration 30, loss = 2.5187456607818604
Epoch 0, iteration 40, loss = 2.508305311203003
Epoch 0, iteration 50, loss = 2.2867414951324463
Epoch 0, iteration 60, loss = 2.4361629486083984
Epoch 0, iteration 70, loss = 1.9577289819717407
Epoch 0, iteration 80, loss = 1.8566584587097168
Epoch 0, iteration 90, loss = 1.7572576999664307
Epoch 0 done
[Acc] Got 32644602 / 65536000 correct (49.81)
Epoch 1, iteration 0, loss = 1.8496345281600952
Epoch 1, iteration 10, loss = 1.853102684020996
Epoch 1, iteration 20, loss = 1.6164289712905884
Epoch 1, iteration 30, loss = 1.6331778764724731
Epoch 1, iteration 40, loss = 1.6161593198776245
Epoch 1, iteration 50, loss = 1.4898713827133179
Epoch 1, iteration 60, loss = 1.5986688137054443
Epoch 1, iteration 70, loss = 1.4365719556808472
Epoch 1, iteration 80, loss = 1.4092587232589722
Epoch 1, iteration 90, loss = 1.3454973697662354
Epoch 1 done
[Acc] Got 40264432 / 65536000 correct (61.44)
Epoch 2, iteration 0, loss = 1.2678680419921875
Epoch 2, iteration 10, loss = 1.5047433376312256
Epoch 2, iteration 20, loss = 1.2447785139083862
Epoch 2, iteration 30, loss = 1.1710916757583618
Epoch 2, iteration 40, loss = 1.1267248392105103
Epoch 2, iteration 50, loss = 1.1480406522750854
Epoch 2, iteration 60, loss = 1.1601018905639648
Epoch 2, iteration 70, loss = 1.1152286529541016
Epoch 2, iteration 80, loss = 1.1860275268554688
Epoch 2, iteration 90, loss = 1.128713607788086
Epoch 2 done
[Acc] Got 44228638 / 65536000 correct (67.49)
Epoch 3, iteration 0, loss = 0.8769785165786743
Epoch 3, iteration 10, loss = 0.8874343633651733
Epoch 3, iteration 20, loss = 1.0602900981903076
Epoch 3, iteration 30, loss = 0.9333657622337341
Epoch 3, iteration 40, loss = 0.8524754047393799
Epoch 3, iteration 50, loss = 0.8480112552642822
Epoch 3, iteration 60, loss = 0.9142754673957825
Epoch 3, iteration 70, loss = 0.9016153812408447
Epoch 3, iteration 80, loss = 0.9379053115844727
Epoch 3, iteration 90, loss = 0.917534351348877
Epoch 3 done
[Acc] Got 44766967 / 65536000 correct (68.31)
Epoch 4, iteration 0, loss = 0.7357617616653442
Epoch 4, iteration 10, loss = 0.792691707611084
Epoch 4, iteration 20, loss = 0.780913770198822
Epoch 4, iteration 30, loss = 0.7233166694641113
Epoch 4, iteration 40, loss = 0.6928702592849731
Epoch 4, iteration 50, loss = 0.7370774149894714
Epoch 4, iteration 60, loss = 0.6497021913528442
Epoch 4, iteration 70, loss = 0.6308227777481079
Epoch 4, iteration 80, loss = 0.6553708910942078
Epoch 4, iteration 90, loss = 0.6834204792976379
Epoch 4 done
[Acc] Got 48340190 / 65536000 correct (73.76)
Epoch 5, iteration 0, loss = 0.6047251224517822
Epoch 5, iteration 10, loss = 0.6590820550918579
Epoch 5, iteration 20, loss = 0.6733068227767944
Epoch 5, iteration 30, loss = 0.5782275795936584
Epoch 5, iteration 40, loss = 0.5867722630500793
Epoch 5, iteration 50, loss = 0.5336177945137024
Epoch 5, iteration 60, loss = 0.6968402862548828
Epoch 5, iteration 70, loss = 0.632849931716919
Epoch 5, iteration 80, loss = 0.5604737401008606
Epoch 5, iteration 90, loss = 0.6173032522201538
Epoch 5 done
[Acc] Got 48585277 / 65536000 correct (74.14)
Epoch 6, iteration 0, loss = 0.5196853876113892
Epoch 6, iteration 10, loss = 0.5086436867713928
Epoch 6, iteration 20, loss = 0.5854882597923279
Epoch 6, iteration 30, loss = 0.6928310394287109
Epoch 6, iteration 40, loss = 0.5328149795532227
Epoch 6, iteration 50, loss = 0.41671279072761536
Epoch 6, iteration 60, loss = 0.5883469581604004
Epoch 6, iteration 70, loss = 0.6042952537536621
Epoch 6, iteration 80, loss = 0.5938699245452881
Epoch 6, iteration 90, loss = 0.5658499598503113
Epoch 6 done
[Acc] Got 49459192 / 65536000 correct (75.47)
Epoch 7, iteration 0, loss = 0.4994562566280365
Epoch 7, iteration 10, loss = 0.42685115337371826
Epoch 7, iteration 20, loss = 0.41423314809799194
Epoch 7, iteration 30, loss = 0.4584282636642456
Epoch 7, iteration 40, loss = 0.406289666891098
Epoch 7, iteration 50, loss = 0.5013749003410339
Epoch 7, iteration 60, loss = 0.5198830962181091
Epoch 7, iteration 70, loss = 0.4176245927810669
Epoch 7, iteration 80, loss = 0.4864674210548401
Epoch 7, iteration 90, loss = 0.4675625264644623
Epoch 7 done
[Acc] Got 49507686 / 65536000 correct (75.54)
Epoch 8, iteration 0, loss = 0.4036969542503357
Epoch 8, iteration 10, loss = 0.4940637946128845
Epoch 8, iteration 20, loss = 0.4088168442249298
Epoch 8, iteration 30, loss = 0.4945453703403473
Epoch 8, iteration 40, loss = 0.4359357953071594
Epoch 8, iteration 50, loss = 0.48374372720718384
Epoch 8, iteration 60, loss = 0.4908166527748108
Epoch 8, iteration 70, loss = 0.4532809257507324
Epoch 8, iteration 80, loss = 0.36817148327827454
Epoch 8, iteration 90, loss = 0.3807017505168915
Epoch 8 done
[Acc] Got 50222162 / 65536000 correct (76.63)
Epoch 9, iteration 0, loss = 0.46588587760925293
Epoch 9, iteration 10, loss = 0.3540979325771332
Epoch 9, iteration 20, loss = 0.3933388888835907
Epoch 9, iteration 30, loss = 0.37464767694473267
Epoch 9, iteration 40, loss = 0.32946231961250305
Epoch 9, iteration 50, loss = 0.4122910797595978
Epoch 9, iteration 60, loss = 0.3129812777042389
Epoch 9, iteration 70, loss = 0.3862900733947754
Epoch 9, iteration 80, loss = 0.3184240758419037
Epoch 9, iteration 90, loss = 0.3778306245803833
Epoch 9 done
[Acc] Got 49275314 / 65536000 correct (75.19)
Epoch 10, iteration 0, loss = 0.2907436490058899
Epoch 10, iteration 10, loss = 0.36047449707984924
Epoch 10, iteration 20, loss = 0.29311102628707886
Epoch 10, iteration 30, loss = 0.36856725811958313
Epoch 10, iteration 40, loss = 0.3548986613750458
Epoch 10, iteration 50, loss = 0.32243484258651733
Epoch 10, iteration 60, loss = 0.3436042070388794
Epoch 10, iteration 70, loss = 0.3320477306842804
Epoch 10, iteration 80, loss = 0.3585873544216156
Epoch 10, iteration 90, loss = 0.40507972240448
Epoch 10 done
[Acc] Got 50327525 / 65536000 correct (76.79)
Epoch 11, iteration 0, loss = 0.30078133940696716
Epoch 11, iteration 10, loss = 0.31642091274261475
Epoch 11, iteration 20, loss = 0.2745259404182434
Epoch 11, iteration 30, loss = 0.2700251042842865
Epoch 11, iteration 40, loss = 0.3196711540222168
Epoch 11, iteration 50, loss = 0.2761358320713043
Epoch 11, iteration 60, loss = 0.2873099744319916
Epoch 11, iteration 70, loss = 0.2964515686035156
Epoch 11, iteration 80, loss = 0.31377366185188293
Epoch 11, iteration 90, loss = 0.2647247612476349
Epoch 11 done
[Acc] Got 50534094 / 65536000 correct (77.11)
Epoch 12, iteration 0, loss = 0.2631071209907532
Epoch 12, iteration 10, loss = 0.3709346652030945
Epoch 12, iteration 20, loss = 0.2663397490978241
Epoch 12, iteration 30, loss = 0.28251898288726807
Epoch 12, iteration 40, loss = 0.3104124963283539
Epoch 12, iteration 50, loss = 0.2776812016963959
Epoch 12, iteration 60, loss = 0.2807433009147644
Epoch 12, iteration 70, loss = 0.2893621623516083
Epoch 12, iteration 80, loss = 0.3566482663154602
Epoch 12, iteration 90, loss = 0.29072433710098267
Epoch 12 done
[Acc] Got 50548027 / 65536000 correct (77.13)
Epoch 13, iteration 0, loss = 0.29906585812568665
Epoch 13, iteration 10, loss = 0.28416478633880615
Epoch 13, iteration 20, loss = 0.3200085759162903
Epoch 13, iteration 30, loss = 0.30321961641311646
Epoch 13, iteration 40, loss = 0.28119754791259766
Epoch 13, iteration 50, loss = 0.3090986907482147
Epoch 13, iteration 60, loss = 0.25297197699546814
Epoch 13, iteration 70, loss = 0.28880131244659424
Epoch 13, iteration 80, loss = 0.28377172350883484
Epoch 13, iteration 90, loss = 0.3081030249595642
Epoch 13 done
[Acc] Got 50611871 / 65536000 correct (77.23)
Epoch 14, iteration 0, loss = 0.2777388095855713
Epoch 14, iteration 10, loss = 0.2680618166923523
Epoch 14, iteration 20, loss = 0.317973256111145
Epoch 14, iteration 30, loss = 0.2640048563480377
Epoch 14, iteration 40, loss = 0.24710425734519958
Epoch 14, iteration 50, loss = 0.29299378395080566
Epoch 14, iteration 60, loss = 0.2660341262817383
Epoch 14, iteration 70, loss = 0.31922027468681335
Epoch 14, iteration 80, loss = 0.3028830289840698
Epoch 14, iteration 90, loss = 0.2589528560638428
Epoch 14 done
[Acc] Got 50693896 / 65536000 correct (77.35)
Epoch 15, iteration 0, loss = 0.28195223212242126
Epoch 15, iteration 10, loss = 0.2804413437843323
Epoch 15, iteration 20, loss = 0.2517385482788086
Epoch 15, iteration 30, loss = 0.2311534732580185
Epoch 15, iteration 40, loss = 0.2568943202495575
Epoch 15, iteration 50, loss = 0.2699096202850342
Epoch 15, iteration 60, loss = 0.26955634355545044
Epoch 15, iteration 70, loss = 0.27648988366127014
Epoch 15, iteration 80, loss = 0.2281632274389267
Epoch 15, iteration 90, loss = 0.2567914128303528
Epoch 15 done
[Acc] Got 50736382 / 65536000 correct (77.42)
Epoch 16, iteration 0, loss = 0.30966833233833313
Epoch 16, iteration 10, loss = 0.2716463506221771
Epoch 16, iteration 20, loss = 0.23762747645378113
Epoch 16, iteration 30, loss = 0.24094948172569275
Epoch 16, iteration 40, loss = 0.2996186912059784
Epoch 16, iteration 50, loss = 0.2868756055831909
Epoch 16, iteration 60, loss = 0.25265511870384216
Epoch 16, iteration 70, loss = 0.2533949613571167
Epoch 16, iteration 80, loss = 0.24487940967082977
Epoch 16, iteration 90, loss = 0.27977606654167175
Epoch 16 done
[Acc] Got 50788088 / 65536000 correct (77.50)
Epoch 17, iteration 0, loss = 0.2606678009033203
Epoch 17, iteration 10, loss = 0.23847146332263947
Epoch 17, iteration 20, loss = 0.2461446076631546
Epoch 17, iteration 30, loss = 0.2647072672843933
Epoch 17, iteration 40, loss = 0.2724859416484833
Epoch 17, iteration 50, loss = 0.27458974719047546
Epoch 17, iteration 60, loss = 0.23789352178573608
Epoch 17, iteration 70, loss = 0.23401415348052979
Epoch 17, iteration 80, loss = 0.25124919414520264
Epoch 17, iteration 90, loss = 0.22843174636363983
Epoch 17 done
[Acc] Got 50829379 / 65536000 correct (77.56)
Epoch 18, iteration 0, loss = 0.2592020034790039
Epoch 18, iteration 10, loss = 0.26451829075813293
Epoch 18, iteration 20, loss = 0.23266899585723877
Epoch 18, iteration 30, loss = 0.2652112543582916
Epoch 18, iteration 40, loss = 0.31155991554260254
Epoch 18, iteration 50, loss = 0.24660316109657288
Epoch 18, iteration 60, loss = 0.2530708611011505
Epoch 18, iteration 70, loss = 0.26066720485687256
Epoch 18, iteration 80, loss = 0.2664709985256195
Epoch 18, iteration 90, loss = 0.24483118951320648
Epoch 18 done
[Acc] Got 50852476 / 65536000 correct (77.59)
Epoch 19, iteration 0, loss = 0.28602084517478943
Epoch 19, iteration 10, loss = 0.2499474436044693
Epoch 19, iteration 20, loss = 0.23145028948783875
Epoch 19, iteration 30, loss = 0.2147887945175171
Epoch 19, iteration 40, loss = 0.25263291597366333
Epoch 19, iteration 50, loss = 0.24309957027435303
Epoch 19, iteration 60, loss = 0.2003411054611206
Epoch 19, iteration 70, loss = 0.24445290863513947
Epoch 19, iteration 80, loss = 0.26503413915634155
Epoch 19, iteration 90, loss = 0.2178436666727066
Epoch 19 done
[Acc] Got 50911802 / 65536000 correct (77.69)
Epoch 20, iteration 0, loss = 0.2739868462085724
Epoch 20, iteration 10, loss = 0.17937368154525757
Epoch 20, iteration 20, loss = 0.2647779583930969
Epoch 20, iteration 30, loss = 0.27983516454696655
Epoch 20, iteration 40, loss = 0.22576594352722168
Epoch 20, iteration 50, loss = 0.23262038826942444
Epoch 20, iteration 60, loss = 0.24332231283187866
Epoch 20, iteration 70, loss = 0.23796558380126953
Epoch 20, iteration 80, loss = 0.22480015456676483
Epoch 20, iteration 90, loss = 0.2157441885546875
Epoch 20 done
[Acc] Got 50981405 / 65536000 correct (77.79)
Epoch 21, iteration 0, loss = 0.23990315198898315
Epoch 21, iteration 10, loss = 0.26503807306289673
Epoch 21, iteration 20, loss = 0.23275834321975708
Epoch 21, iteration 30, loss = 0.2418297976255417
Epoch 21, iteration 40, loss = 0.24351932108402252
Epoch 21, iteration 50, loss = 0.27832695841789246
Epoch 21, iteration 60, loss = 0.23205767571926117
Epoch 21, iteration 70, loss = 0.2113640308380127
Epoch 21, iteration 80, loss = 0.25845709443092346
Epoch 21, iteration 90, loss = 0.23790058493614197
Epoch 21 done
[Acc] Got 50925452 / 65536000 correct (77.71)
Epoch 22, iteration 0, loss = 0.23740394413471222
Epoch 22, iteration 10, loss = 0.24445924162864685
Epoch 22, iteration 20, loss = 0.2687460780143738
Epoch 22, iteration 30, loss = 0.20516453683376312
Epoch 22, iteration 40, loss = 0.20121559500694275
Epoch 22, iteration 50, loss = 0.22019052505493164
Epoch 22, iteration 60, loss = 0.23598328232765198
Epoch 22, iteration 70, loss = 0.18659521639347076
Epoch 22, iteration 80, loss = 0.23251952230930328
Epoch 22, iteration 90, loss = 0.21074414253234863
Epoch 22 done
[Acc] Got 50897579 / 65536000 correct (77.66)
Epoch 23, iteration 0, loss = 0.21150906383991241
Epoch 23, iteration 10, loss = 0.17958973348140717
Epoch 23, iteration 20, loss = 0.256083220243454
Epoch 23, iteration 30, loss = 0.22666116058826447
Epoch 23, iteration 40, loss = 0.2972954213619232
Epoch 23, iteration 50, loss = 0.23105847835540771
Epoch 23, iteration 60, loss = 0.21751855313777924
Epoch 23, iteration 70, loss = 0.23931296169757843
Epoch 23, iteration 80, loss = 0.2350941300392151
Epoch 23, iteration 90, loss = 0.26259616017341614
Epoch 23 done
[Acc] Got 50981948 / 65536000 correct (77.79)
Epoch 24, iteration 0, loss = 0.22478987276554108
Epoch 24, iteration 10, loss = 0.2334124743938446
Epoch 24, iteration 20, loss = 0.20308785140514374
Epoch 24, iteration 30, loss = 0.2234630137681961
Epoch 24, iteration 40, loss = 0.23612076044082642
Epoch 24, iteration 50, loss = 0.25721409916877747
Epoch 24, iteration 60, loss = 0.2700474262237549
Epoch 24, iteration 70, loss = 0.21128010749816895
Epoch 24, iteration 80, loss = 0.20631177723407745
Epoch 24, iteration 90, loss = 0.2860499620437622
Epoch 24 done
"""

# Loss lines look like "Epoch E, iteration I, loss = L". Capture the epoch to mark
# epoch boundaries, and let float() parse the value so scientific notation
# ("1.2e-05") and "nan"/"inf" are handled; a character-class regex such as
# [\d\.]+ would silently truncate or drop them.
LOSS_LINE = re.compile(r"Epoch (\d+), iteration \d+, loss = (\S+)")

epochs, losses = [], []
for line in raw_logs.splitlines():
    match = LOSS_LINE.search(line)
    if match:
        epochs.append(int(match.group(1)))
        losses.append(float(match.group(2)))

# Each point is one *logged* loss value, not one optimizer step, so the x-axis is
# labelled as logged steps; dashed lines mark where a new epoch begins.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(losses)
for step in range(1, len(epochs)):
    if epochs[step] != epochs[step - 1]:
        ax.axvline(step, color='gray', linestyle='--', linewidth=0.5)
ax.set_title("Training Loss")
ax.set_xlabel("Logged step (dashed lines = epoch boundaries)")
ax.set_ylabel("Loss")
ax.grid(True)
plt.show()