In [1]:
import plotly.graph_objects as go
import sys
import os
from pathlib import Path
import numpy as np
import random
import math
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

## <ins>Exploratory Data Analysis - Object File format for 3D Shapes</ins>

### 1. Identifying the Classes and moving into Hashmap

In [2]:
# Root of the raw ModelNet40 download (one sub-directory per class).
# Set the MODELNET40_DIR environment variable to point somewhere else;
# the original cluster location is kept as the fallback.
path = Path(os.environ.get("MODELNET40_DIR", "/scratch/rpushpar/Datasets/ModelNet40/raw"))

# Map class index -> class name. Only sub-directories count as classes, which is the
# same rule PointCloudData applies, so both index maps agree even if stray files
# (archives, metadata, ...) sit next to the class folders.
classes = {i: folder for i, folder in enumerate(sorted(entry.name for entry in path.iterdir() if entry.is_dir()))}
classes


{0: 'airplane',
 1: 'bathtub',
 2: 'bed',
 3: 'bench',
 4: 'bookshelf',
 5: 'bottle',
 6: 'bowl',
 7: 'car',
 8: 'chair',
 9: 'cone',
 10: 'cup',
 11: 'curtain',
 12: 'desk',
 13: 'door',
 14: 'dresser',
 15: 'flower_pot',
 16: 'glass_box',
 17: 'guitar',
 18: 'keyboard',
 19: 'lamp',
 20: 'laptop',
 21: 'mantel',
 22: 'monitor',
 23: 'night_stand',
 24: 'person',
 25: 'piano',
 26: 'plant',
 27: 'radio',
 28: 'range_hood',
 29: 'sink',
 30: 'sofa',
 31: 'stairs',
 32: 'stool',
 33: 'table',
 34: 'tent',
 35: 'toilet',
 36: 'tv_stand',
 37: 'vase',
 38: 'wardrobe',
 39: 'xbox'}

### 2. Exploring a sample datapoint (.OFF file) and extracting its Vertices and Faces

In [3]:
def OFF_data_representer(file):
    """Parse an OFF (Object File Format) mesh from an open text file.

    Two header layouts are accepted: ``OFF`` alone on the first line with the
    counts on the next line, or the counts glued onto the header line
    (``OFF<n_vertices> <n_faces> <n_edges>``).

    Args:
        file: Text-mode file object positioned at the start of the OFF data.

    Returns:
        (vertices, faces): ``vertices`` is a list of per-vertex float lists and
        ``faces`` is a list of per-face vertex-index lists.

    Raises:
        ValueError: If the header does not start with ``OFF``, a count or value
            cannot be parsed, or the data ends before all declared vertices and
            faces were read.
    """
    header = file.readline().strip()
    if not header.startswith("OFF"):
        raise ValueError(f"Not an OFF file: header line is {header!r}")

    # The counts either follow "OFF" on the same line or sit on the next line.
    counts_text = header[3:].strip()
    if not counts_text:
        counts_text = file.readline()
    # split() without arguments tolerates tabs and repeated spaces.
    n_vertices, n_faces = (int(token) for token in counts_text.split()[:2])

    vertices = []
    for _ in range(n_vertices):
        coords = file.readline().split()
        if not coords:
            raise ValueError("OFF data ended before all vertices were read")
        vertices.append([float(c) for c in coords])

    faces = []
    for _ in range(n_faces):
        tokens = file.readline().split()
        if not tokens:
            raise ValueError("OFF data ended before all faces were read")
        # The first token is the number of corners of this face. Only that many
        # indices are kept, so optional trailing colour fields are ignored and
        # multi-digit counts or leading whitespace are handled correctly.
        n_corners = int(tokens[0])
        faces.append([int(token) for token in tokens[1:1 + n_corners]])

    return vertices, faces

In [4]:
# Parse one training mesh so its raw geometry can be inspected
with (path/"airplane/train/airplane_0001.off").open("r") as f:
    vertices, faces = OFF_data_representer(f)

# Peek at the first five vertices and the first five faces
vertices[:5], faces[:5]

([[20.967, -26.1154, 46.5444],
  [21.0619, -26.091, 46.5031],
  [-83.1524, -52.8062, 91.8328],
  [20.967, -26.1154, 46.5444],
  [0.572407, -48.35, 93.2093]],
 [[24, 25, 26], [25, 24, 27], [28, 29, 30], [31, 30, 29], [27, 32, 25]])

### 3. Displaying the sample 3D shape 

#### 3.1. 3D Mesh Display

In [5]:
# Transpose the vertex list into per-axis coordinate arrays and the face list
# into per-corner index arrays, which is the layout Plotly's Mesh3d expects
x, y, z = np.array(vertices).T
i, j, k = np.array(faces).T

# Mesh3d draws one triangle per (i, j, k) triple of indices into (x, y, z)
data = go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color='lightpink', opacity=0.5)
fig = go.Figure(data=data)
#fig.show()

#### 3.2. 3D Scatter Display

In [6]:
# Scatter3d draws every vertex as a small marker, coloured by its z value
data = go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker={'size': 1, 'color': z, 'colorscale': 'Viridis', 'opacity': 0.5},
)
fig = go.Figure(data=data)
#fig.show()

## <ins>Data Conversion - Vertices and Faces to Point Cloud Data</ins>

### 1. Point Sampler

**Point Sampler** is mainly used to **convert mesh data (vertices and faces) into a point cloud** by sampling points inside the triangles that make up the mesh. This is done by generating points using **barycentric coordinates**. 
**Downsampling** can also be an aspect of Point Sampling, where the number of points is reduced, but the goal is to preserve the shape and geometry of the 3D object in the new point cloud representation. It ensures that the **3D shape is still intact** by sampling points **inside the faces**, thus maintaining a proper representation of the object for tasks such as 3D classification or segmentation.

Currently, we have only the vertices (the coordinates (*x, y, z*) of each point in space) and the faces, which specify which three vertices form each triangle; together these triangles make up the surface of the 3D object. Given this, we aim to construct the point cloud for each object by modelling it mathematically. To do so, we follow the steps below:

**Triangle Area** : Calculates the area of a triangle defined by three points (pt1, pt2, pt3) in 3D space.

1. First, we calculate the distance between each pair of vertices (side_a, side_b, side_c) using np.linalg.norm, which computes the Euclidean distance between two points.
2. Then it calculates the semi-perimeter (s) and uses Heron’s formula to compute the area of the triangle.
3. This area is useful in weighted sampling, where larger triangles will have a higher probability of being sampled

**Sample Point** : Samples a random point inside a triangle defined by pt1, pt2, and pt3.

1. Two random numbers (s, t) are generated and sorted to determine the relative positions of the point within the triangle.
2. The lambda function f(i) computes the coordinates of the sampled point based on the barycentric interpolation formula, which ensures the sampled point lies inside the triangle.

The *__call__* function receives the vertices and faces as arguments. Moving forward: 
- The method first computes the area of each triangle.
- It samples triangles in proportion to their area using random.choices (which selects triangles with larger area more frequently).
- After a triangle is selected, its corresponding vertices are passed to the Sample Point method, which returns a sampled point.
- The sample points are stored in a sample_point numpy array with shape (output_size, 3)


In [7]:
class PointSampler(object):
    """Convert a triangle mesh into a fixed-size point cloud.

    Faces are drawn with probability proportional to their area, then one
    uniformly distributed point is drawn inside each selected face.
    Randomness comes from Python's ``random`` module, so ``random.seed``
    controls reproducibility.
    """

    def __init__(self, output_size):
        """Store the number of points every call should produce.

        Args:
            output_size: Number of points to sample per mesh (int).
        """
        assert isinstance(output_size, int)
        self.output_size = output_size

    def triangle_area(self, pt1, pt2, pt3):
        """Return the area of the triangle (pt1, pt2, pt3) via Heron's formula.

        The points are numpy vectors. The product under the square root is
        clamped at 0 so that rounding on nearly degenerate triangles cannot
        produce a negative radicand.
        """
        side_a = np.linalg.norm(pt1 - pt2)
        side_b = np.linalg.norm(pt2 - pt3)
        side_c = np.linalg.norm(pt3 - pt1)
        s = 0.5 * ( side_a + side_b + side_c)
        return max(s * (s - side_a) * (s - side_b) * (s - side_c), 0)**0.5

    def sample_point(self, pt1, pt2, pt3):
        """Draw one uniformly distributed point inside the triangle (pt1, pt2, pt3).

        Uses barycentric coordinates
        (https://mathworld.wolfram.com/BarycentricCoordinates.html): a point P
        lies in the triangle when

            P = λ1.pt1 + λ2.pt2 + λ3.pt3     with λ1, λ2, λ3 >= 0 and λ1 + λ2 + λ3 = 1

        Two random numbers are sorted so that s <= t, and the weights are

            λ1 = s,   λ2 = (t - s),   λ3 = (1 - t)

        Worked example with pt1=(0,0,0), pt2=(1,0,0), pt3=(0,1,0), s=0.2, t=0.6:

            λ1 = 0.2,   λ2 = 0.6 - 0.2 = 0.4,   λ3 = 1 - 0.6 = 0.4

            x = 0.2*0 + 0.4*1 + 0.4*0 = 0.4
            y = 0.2*0 + 0.4*0 + 0.4*1 = 0.4
            z = 0.2*0 + 0.4*0 + 0.4*0 = 0

        so the sampled point is (0.4, 0.4, 0).

        Returns:
            A 3-tuple with the first three coordinates of the sampled point.
        """
        # Two random floats in [0, 1), sorted so that s <= t
        s, t = sorted([random.random(), random.random()])

        f = lambda i: s * pt1[i] + (t-s)*pt2[i] + (1-t)*pt3[i]
        return (f(0), f(1), f(2))

    def __call__(self, mesh):
        """Sample ``output_size`` points from the surface of ``mesh``.

        Args:
            mesh: ``(vertices, faces)`` pair; ``vertices`` is a sequence of
                coordinate rows and ``faces`` a sequence of vertex-index rows.
                Only the first three indices of each face are used.

        Returns:
            numpy array of shape ``(output_size, 3)``.

        Raises:
            ValueError: If the mesh has no faces.
        """
        vertices, faces = mesh
        if len(faces) == 0:
            raise ValueError("Cannot sample points from a mesh without faces")

        vertices = np.asarray(vertices, dtype=float)
        corners = np.array([face[:3] for face in faces], dtype=np.int64)
        pt1 = vertices[corners[:, 0]]
        pt2 = vertices[corners[:, 1]]
        pt3 = vertices[corners[:, 2]]

        # Heron's formula for all faces at once instead of one Python call per face
        side_a = np.linalg.norm(pt1 - pt2, axis=1)
        side_b = np.linalg.norm(pt2 - pt3, axis=1)
        side_c = np.linalg.norm(pt3 - pt1, axis=1)
        s = 0.5 * (side_a + side_b + side_c)
        areas = np.sqrt(np.maximum(s * (s - side_a) * (s - side_b) * (s - side_c), 0))

        # random.choices rejects weights that sum to zero; a mesh made only of
        # degenerate faces therefore falls back to uniform face selection.
        weights = areas.tolist() if areas.sum() > 0 else None
        sampled_faces = random.choices(range(len(faces)), weights=weights, k=self.output_size)

        sampled_points = np.zeros((self.output_size, 3))
        for row, face_idx in enumerate(sampled_faces):
            sampled_points[row] = self.sample_point(pt1[face_idx], pt2[face_idx], pt3[face_idx])

        return sampled_points

In [8]:
# Draw 4000 surface samples from the sample mesh and plot them, coloured by z
pointcloud = PointSampler(4000)((vertices, faces))
x, y, z = pointcloud.T
data = go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker={'size': 1, 'color': z, 'colorscale': 'Viridis', 'opacity': 0.5},
)
fig = go.Figure(data=data)
#fig.show()

*Inference*:

Thus, The **PointSampler class** is **not performing random reductions** of an existing point cloud, but instead it is **sampling points** from the **mesh's faces** to generate **a new point cloud**.
The downsampling helps **reduce the complexity of the data** (e.g., from a mesh with a large number of faces to a fixed number of points) while still keeping the surface of the object intact for further processing (such as training a machine learning model).

## <ins>Data Augmentation</ins>

### 1. Normalize Data

In [9]:
class Normalize(object):
    """Centre a point cloud on its mean and scale it so the farthest point has norm 1."""

    def __call__(self, pointcloud):
        """Return a normalised copy of ``pointcloud``.

        Args:
            pointcloud: 2-D numpy array with one point per row.

        Returns:
            Array of the same shape, zero-mean per column. If every point
            coincides (maximum radius 0) the centred cloud is returned
            unscaled instead of dividing by zero and producing NaNs.
        """
        assert len(pointcloud.shape)==2

        norm_pointcloud = pointcloud - np.mean(pointcloud, axis=0)
        max_radius = np.max(np.linalg.norm(norm_pointcloud, axis=1))
        if max_radius > 0:
            norm_pointcloud /= max_radius

        return norm_pointcloud

### 2. Random Rotation along Z-axis

In [10]:
class RandRotation_z(object):
    """Rotate a point cloud about the z-axis by a random angle in [0, 2*pi)."""

    def __call__(self, pointcloud):
        """Return the rotated copy of a 2-D ``(n_points, 3)`` array.

        The angle is drawn with ``random.random()``, so ``random.seed``
        controls it.
        """
        assert pointcloud.ndim == 2

        angle = random.random() * 2. * math.pi
        cos_a, sin_a = math.cos(angle), math.sin(angle)
        # Standard rotation about z: x and y mix, z is left untouched
        rotation = np.array([[cos_a, -sin_a, 0],
                             [sin_a,  cos_a, 0],
                             [0,      0,     1]])

        return np.matmul(rotation, pointcloud.T).T

### 3. Random Noise

In [11]:
class RandomNoise(object):
    """Jitter every coordinate with independent zero-mean Gaussian noise."""

    def __init__(self, std=0.02):
        """Configure the jitter strength.

        Args:
            std: Standard deviation of the noise. The default 0.02 is the
                value that used to be hard-coded, so ``RandomNoise()``
                behaves as before.
        """
        self.std = std

    def __call__(self, pointcloud):
        """Return ``pointcloud`` plus noise drawn from ``np.random.normal``.

        Args:
            pointcloud: 2-D numpy array with one point per row.

        Returns:
            New array of the same shape; the input is not modified.
        """
        assert len(pointcloud.shape)==2

        noise = np.random.normal(0, self.std, pointcloud.shape)

        noisy_pointcloud = pointcloud + noise
        return noisy_pointcloud

### 4. Convert the Augmented Point Cloud to a Tensor

In [12]:
class ToTensor(object):
    """Wrap a 2-D numpy point cloud as a torch tensor via ``torch.from_numpy``."""

    def __call__(self, pointcloud):
        """Return the tensor view of ``pointcloud`` (same dtype, same shape)."""
        assert pointcloud.ndim == 2

        tensor = torch.from_numpy(pointcloud)
        return tensor

## <ins>Dataset Preparation</ins>

### 1. Data Transform Methods

In [13]:
# Used for everything except the training set (validation and testing): no augmentation
def default_transforms(num_points=1024):
    """Build the augmentation-free pipeline: sample points, normalise, convert to tensor.

    Args:
        num_points: Number of points sampled from each mesh. Defaults to 1024,
            the value that used to be hard-coded.

    Returns:
        A ``transforms.Compose`` that maps a ``(vertices, faces)`` mesh to a tensor.
    """
    return transforms.Compose([
        PointSampler(num_points),
        Normalize(),
        ToTensor()])

# Used only for the training set: adds random z-rotation and Gaussian jitter
def train_transforms(num_points=1024):
    """Build the training pipeline: sample, normalise, rotate, jitter, convert to tensor.

    Args:
        num_points: Number of points sampled from each mesh. Defaults to 1024,
            the value that used to be hard-coded.

    Returns:
        A ``transforms.Compose`` that maps a ``(vertices, faces)`` mesh to a tensor.
    """
    return transforms.Compose([
        PointSampler(num_points),
        Normalize(),
        RandRotation_z(),
        RandomNoise(),
        ToTensor()])

### 2. Data Preprocessing Method

In [14]:
class PointCloudData(Dataset):
    """Dataset of OFF meshes laid out as ``root_dir/<class>/<folder>/*.off``.

    Each item is a dict with the transformed mesh under ``'pointcloud'`` and
    the integer class index under ``'category'``.
    """

    def __init__(self, root_dir, valid=False, folder="train", transform=default_transforms()):
        """Index every ``.off`` file of the requested split.

        Args:
            root_dir: Directory containing one sub-directory per class
                (``str`` or ``Path``).
            valid: When True, ``transform`` is ignored and the pipeline from
                ``default_transforms()`` is used instead.
            folder: Name of the split sub-directory inside each class folder.
            transform: Callable applied to the ``(vertices, faces)`` pair.
        """
        # Accept plain strings as well as Path objects; the "/" joins below need a Path.
        root_dir = Path(root_dir)
        self.root_dir = root_dir
        class_names = [entry for entry in sorted(os.listdir(root_dir)) if os.path.isdir(root_dir/entry)]
        self.classes = {name: idx for idx, name in enumerate(class_names)}
        self.transforms = transform if not valid else default_transforms()
        self.valid = valid
        self.files = []
        for category in self.classes.keys():
            new_dir = root_dir/Path(category)/folder
            # os.listdir order is filesystem-dependent; sorting keeps the
            # index -> file mapping stable between runs and machines.
            for file in sorted(os.listdir(new_dir)):
                if file.endswith('.off'):
                    self.files.append({'pcd_path': new_dir/file, 'category': category})

    def __len__(self):
        """Number of indexed ``.off`` files."""
        return len(self.files)

    def __preproc__(self, file):
        """Parse an open OFF file and run it through the transform pipeline.

        Returns the transformed mesh, or the raw ``(vertices, faces)`` pair
        when no transform is configured.
        """
        verts, faces = OFF_data_representer(file)
        if self.transforms:
            return self.transforms((verts, faces))
        return verts, faces

    def __getitem__(self, idx):
        """Load, parse and transform the ``idx``-th mesh."""
        entry = self.files[idx]
        with open(entry['pcd_path'], 'r') as f:
            pointcloud = self.__preproc__(f)
        return {'pointcloud': pointcloud,
                'category': self.classes[entry['category']]}

### 3. Dataset definition

In [15]:
# Training split (default folder "train") with the augmenting pipeline
train_ds = PointCloudData(root_dir=path, transform=train_transforms())
# Held-out split read from each class's 'test' folder; valid=True selects the augmentation-free pipeline
valid_ds = PointCloudData(root_dir=path, valid=True, folder='test', transform=default_transforms())

In [16]:
# Index the dataset once: every train_ds[0] access re-reads the mesh and re-runs
# the random sampling/augmentation, so two accesses describe two different draws.
first_sample = train_ds[0]

print('Train dataset size: ', len(train_ds))
print('Valid dataset size: ', len(valid_ds))
print('Number of classes: ', len(train_ds.classes))
print('Sample pointcloud shape: ', first_sample['pointcloud'].size())
print('Class: ', classes[first_sample['category']])

Train dataset size:  9843
Valid dataset size:  2468
Number of classes:  40
Sample pointcloud shape:  torch.Size([1024, 3])
Class:  airplane


In [17]:
# Training batches of 32, reshuffled every epoch
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
# Evaluation batches of 64 in fixed order
valid_loader = DataLoader(valid_ds, batch_size=64)

## <ins>Model Definition</ins>

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PointNetSetAbstraction(nn.Module):
    """One PointNet++ set-abstraction layer: sample centroids, group neighbours, apply a shared MLP.

    With ``group_all=False`` the layer picks ``npoint`` centroids by farthest
    point sampling, gathers ``nsample`` neighbours within ``radius`` of each
    centroid, runs 1x1 convolutions over every neighbour and max-pools per
    centroid. With ``group_all=True`` all points form a single group.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        """
        Args:
            npoint: Number of centroids to sample (unused when ``group_all``).
            radius: Ball-query radius (unused when ``group_all``).
            nsample: Neighbours gathered per centroid (unused when ``group_all``).
            in_channel: Channel count fed to the first convolution
                (coordinate channels plus feature channels).
            mlp: Output widths of the successive 1x1 convolutions.
            group_all: Treat the whole cloud as one group instead of sampling.
        """
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all

    def forward(self, xyz, points=None):
        """
        Input:
            xyz: point coordinates, [B, C, N]
            points: optional per-point features, [B, D, N]
        Return:
            new_xyz: centroid coordinates, [B, C, S]
            new_points: pooled features, [B, mlp[-1], S]
            (S is ``npoint``, or 1 when ``group_all`` is set)
        """
        # Grouping helpers work channel-last: [B, N, C]
        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        if self.group_all:
            new_xyz, new_points = self.sample_and_group_all(xyz, points)
        else:
            new_xyz, new_points = self.sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
        # new_xyz: [B, S, C]; new_points: [B, S, nsample, C+D]
        new_points = new_points.permute(0, 3, 2, 1)  # [B, C+D, nsample, S]
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))

        # Max-pool over the neighbour axis
        new_points = torch.max(new_points, 2)[0]
        new_xyz = new_xyz.permute(0, 2, 1)
        return new_xyz, new_points

    def sample_and_group(self, npoint, radius, nsample, xyz, points, returnfps=False):
        """
        Input:
            npoint: number of centroids to sample
            radius: ball-query radius
            nsample: neighbours gathered per centroid
            xyz: point coordinates, [B, N, C]
            points: optional per-point features, [B, N, D]
            returnfps: accepted for signature compatibility; not used
        Return:
            new_xyz: centroid coordinates, [B, npoint, C]
            new_points: neighbour coordinates relative to their centroid,
                concatenated with neighbour features when given,
                [B, npoint, nsample, C+D]
        """
        B, N, C = xyz.shape
        # Use the arguments that were passed in rather than the layer's own
        # attributes, so the method honours an explicit call.
        fps_idx = self.farthest_point_sample(xyz, npoint)
        new_xyz = self.index_points(xyz, fps_idx)
        idx = self.ball_query(radius, nsample, xyz, new_xyz)
        grouped_xyz = self.index_points(xyz, idx)
        grouped_xyz_norm = grouped_xyz - new_xyz.view(B, npoint, 1, C)
        if points is not None:
            grouped_points = self.index_points(points, idx)
            new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1)
        else:
            new_points = grouped_xyz_norm
        return new_xyz, new_points

    def sample_and_group_all(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, N, 3]
            points: input points data, [B, N, D]
        Return:
            new_xyz: sampled points position data, [B, 1, 3]
            new_points: sampled points data, [B, 1, N, 3+D]
        """
        B, N, C = xyz.shape
        new_xyz = torch.zeros(B, 1, C, device=xyz.device)
        grouped_xyz = xyz.view(B, 1, N, C)
        if points is not None:
            new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
        else:
            new_points = grouped_xyz
        return new_xyz, new_points

    def farthest_point_sample(self, xyz, npoint):
        """
        Input:
            xyz: point coordinates, [B, N, C]
            npoint: number of indices to pick
        Return:
            centroids: indices of the picked points, [B, npoint]
        """
        device = xyz.device
        B, N, C = xyz.shape
        centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
        # Squared distance from every point to its nearest already-picked centroid
        distance = torch.full((B, N), 1e10, device=device)
        # Random start point per batch element (drawn on CPU, then moved)
        farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
        batch_indices = torch.arange(B, dtype=torch.long, device=device)
        for i in range(npoint):
            centroids[:, i] = farthest
            # View with the actual coordinate width C instead of a hard-coded 3
            centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
            dist = torch.sum((xyz - centroid) ** 2, -1)
            mask = dist < distance
            distance[mask] = dist[mask]
            farthest = torch.max(distance, -1)[1]
        return centroids

    def index_points(self, points, idx):
        """
        Input:
            points: [B, N, C]
            idx: indices along the N axis, [B, S] or [B, S, K]
        Return:
            new_points: gathered rows, [B, S, C] or [B, S, K, C]
        """
        B = points.shape[0]
        view_shape = list(idx.shape)
        view_shape[1:] = [1] * (len(view_shape) - 1)
        repeat_shape = list(idx.shape)
        repeat_shape[0] = 1
        # Batch index broadcast to idx's shape so advanced indexing pairs them up
        batch_indices = torch.arange(B, dtype=torch.long, device=points.device).view(view_shape).repeat(repeat_shape)
        new_points = points[batch_indices, idx, :]
        return new_points

    def ball_query(self, radius, nsample, xyz, new_xyz):
        """
        Input:
            radius: search radius
            nsample: neighbours kept per centroid
            xyz: all points, [B, N, C]
            new_xyz: centroids, [B, S, C]
        Return:
            group_idx: neighbour indices, [B, S, nsample]
        """
        B, N, C = xyz.shape
        _, S, _ = new_xyz.shape
        sqrdists = self.square_distance(new_xyz, xyz)
        group_idx = torch.arange(N, dtype=torch.long, device=xyz.device).view(1, 1, N).repeat([B, S, 1])
        # Mark points outside the ball with the out-of-range index N, then sort
        # so in-range indices come first and keep the first nsample of them
        group_idx[sqrdists > radius ** 2] = N
        group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample]

        # Slots still holding N (fewer than nsample neighbours) repeat the first neighbour
        group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
        mask = group_idx == N
        group_idx[mask] = group_first[mask]

        return group_idx

    def square_distance(self, src, dst):
        """
        Pairwise squared Euclidean distance, expanded as
        |s|^2 + |d|^2 - 2 * s.d

        Input:
            src: [B, N, C]
            dst: [B, M, C]
        Return:
            dist: [B, N, M]
        """
        B, N, _ = src.shape
        _, M, _ = dst.shape
        dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
        dist += torch.sum(src ** 2, -1).view(B, N, 1)
        dist += torch.sum(dst ** 2, -1).view(B, 1, M)
        return dist


class PointNetPlusPlus(nn.Module):
    """PointNet++ classifier: three set-abstraction layers followed by a three-layer MLP head."""

    def __init__(self, num_classes):
        """
        Args:
            num_classes: Width of the final linear layer (number of output classes).
        """
        super(PointNetPlusPlus, self).__init__()
        self.sa1 = PointNetSetAbstraction(512, 0.2, 32, 3, [64, 64, 128], group_all=False)
        self.sa2 = PointNetSetAbstraction(128, 0.4, 64, 128 +  3, [128, 128, 256], group_all=False)
        self.sa3 = PointNetSetAbstraction(None, None, None, 256 + 3, [256, 512, 1024], group_all=True)
        self.fc1 = nn.Linear(1024, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.drop2 = nn.Dropout(0.4)
        self.fc3 = nn.Linear(256, num_classes)

    def forward(self, xyz, points=None):
        """
        Input:
            xyz: point cloud, [B, C, N]; the first three channels are the
                coordinates and any further channels are passed to the first
                layer as per-point features
            points: accepted for call compatibility; not used
        Return:
            log-probabilities, [B, num_classes]
        """
        B, _, _ = xyz.shape

        # With coordinate-only input, hand None to the first layer instead of an
        # empty [B, 0, N] tensor; the result is the same without the no-op
        # gather and concatenation.
        norm = xyz[:, 3:, :] if xyz.shape[1] > 3 else None
        xyz = xyz[:, :3, :]
        l1_xyz, l1_points = self.sa1(xyz, norm)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        # sa3 pools the whole cloud into one 1024-wide feature vector
        x = l3_points.view(B, 1024)
        x = self.drop1(F.relu(self.bn1(self.fc1(x))))
        x = self.drop2(F.relu(self.bn2(self.fc2(x))))
        x = self.fc3(x)
        x = F.log_softmax(x, -1)

        return x

In [19]:
def pointnetloss(outputs, labels, m3x3, m64x64, alpha = 0.0001):
    """NLL classification loss plus an orthogonality penalty on two transform matrices.

    Args:
        outputs: Log-probabilities, [B, num_classes].
        labels: Target class indices, [B].
        m3x3: Batch of 3x3 matrices, [B, 3, 3].
        m64x64: Batch of 64x64 matrices, [B, 64, 64].
        alpha: Weight of the penalty term.

    Returns:
        Scalar tensor: ``NLL(outputs, labels) + alpha * (||I - A.A^T|| + ||I - F.F^T||) / B``.
    """
    criterion = torch.nn.NLLLoss()
    bs = outputs.size(0)
    # The identities are constants: create them on the matrices' device and dtype,
    # without requires_grad, and let broadcasting expand them over the batch.
    id3x3 = torch.eye(3, device=m3x3.device, dtype=m3x3.dtype)
    id64x64 = torch.eye(64, device=m64x64.device, dtype=m64x64.dtype)
    diff3x3 = id3x3 - torch.bmm(m3x3, m3x3.transpose(1, 2))
    diff64x64 = id64x64 - torch.bmm(m64x64, m64x64.transpose(1, 2))
    return criterion(outputs, labels) + alpha * (torch.norm(diff3x3) + torch.norm(diff64x64)) / float(bs)

In [20]:
# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [21]:
# Build the 40-class classifier and move it to the selected device
# (nn.Module.to moves the parameters in place and returns the module itself)
pointnet = PointNetPlusPlus(num_classes=40).to(device)

In [22]:
#optimizer = torch.optim.Adam(pointnet.parameters(), lr=0.0008)

In [23]:
def train(model, train_loader, val_loader=None, epochs=10, lr=0.001, save_path="save_pointnetpp.pth"):
    """Train ``model`` with Adam and cross-entropy, reporting accuracy after every epoch.

    Args:
        model: Network called as ``model(inputs.transpose(1, 2), points=None)``.
        train_loader: Iterable of batches; each batch is a dict with a
            ``'pointcloud'`` tensor and a ``'category'`` label tensor.
        val_loader: Optional loader with the same batch layout, evaluated
            after every epoch.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate (default is the previously hard-coded 0.001).
        save_path: File the model's ``state_dict`` is written to after every
            epoch, overwriting the previous checkpoint.
    """
    # Follow the model's own device so batches always land where the weights are.
    run_device = next(model.parameters()).device

    # CrossEntropyLoss applies log-softmax itself. Applying it to outputs that are
    # already log-probabilities leaves them unchanged, so this criterion is
    # correct both for such outputs and for raw logits.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        tqdm_bar = tqdm(enumerate(train_loader, 0), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}")

        for i, data in tqdm_bar:
            inputs = data['pointcloud'].to(run_device).float()
            labels = data['category'].to(run_device)
            optimizer.zero_grad()
            # Dataset yields [B, N, 3]; the model is fed [B, 3, N]
            outputs = model(inputs.transpose(1, 2), points=None)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            tqdm_bar.set_postfix(loss=running_loss / (i + 1), train_acc=correct / total)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError
        final_train_acc = 100. * correct / max(total, 1)
        final_train_loss = running_loss / max(len(train_loader), 1)

        model.eval()
        correct = total = 0
        val_acc = 0

        # Explicit None check: a loader's truthiness depends on its length
        if val_loader is not None:
            with torch.no_grad():
                for data in val_loader:
                    inputs = data['pointcloud'].to(run_device).float()
                    labels = data['category'].to(run_device)
                    outputs = model(inputs.transpose(1, 2), points=None)
                    predicted = outputs.argmax(dim=1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            if total:
                val_acc = 100. * correct / total

        print(f'Epoch {epoch+1}/{epochs} | Train Loss: {final_train_loss:.3f} | Train Acc: {final_train_acc:.2f}% | Val Acc: {val_acc:.2f}%')

        torch.save(model.state_dict(), save_path)

In [None]:
# Run training with the function's default epoch count, validating on the held-out split after every epoch
train(pointnet, train_loader, val_loader=valid_loader)

Epoch 1/10: 100%|██████████| 308/308 [39:26<00:00,  7.68s/it, loss=1.8, train_acc=0.523]  


Epoch 1/10 | Train Loss: 1.804 | Train Acc: 52.26% | Val Acc: 60.33%


Epoch 2/10: 100%|██████████| 308/308 [44:16<00:00,  8.63s/it, loss=1.07, train_acc=0.685] 


Epoch 2/10 | Train Loss: 1.066 | Train Acc: 68.52% | Val Acc: 72.37%


Epoch 3/10: 100%|██████████| 308/308 [50:47<00:00,  9.89s/it, loss=0.881, train_acc=0.736] 


Epoch 3/10 | Train Loss: 0.881 | Train Acc: 73.65% | Val Acc: 75.77%


Epoch 4/10: 100%|██████████| 308/308 [53:18<00:00, 10.38s/it, loss=0.781, train_acc=0.76]   


Epoch 4/10 | Train Loss: 0.781 | Train Acc: 75.96% | Val Acc: 76.05%


Epoch 5/10: 100%|██████████| 308/308 [52:34<00:00, 10.24s/it, loss=0.7, train_acc=0.781]   


Epoch 5/10 | Train Loss: 0.700 | Train Acc: 78.06% | Val Acc: 78.89%


Epoch 6/10: 100%|██████████| 308/308 [43:45<00:00,  8.52s/it, loss=0.685, train_acc=0.786]


Epoch 6/10 | Train Loss: 0.685 | Train Acc: 78.59% | Val Acc: 79.21%


Epoch 7/10: 100%|██████████| 308/308 [40:45<00:00,  7.94s/it, loss=0.64, train_acc=0.798]  


Epoch 7/10 | Train Loss: 0.640 | Train Acc: 79.84% | Val Acc: 78.12%


Epoch 8/10:  98%|█████████▊| 303/308 [37:15<00:32,  6.53s/it, loss=0.605, train_acc=0.809] 