In [1]:
from helpers import *
# Use CUDA when PyTorch reports a usable GPU; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device_in_use = 'cuda'
else:
    device_in_use = 'cpu'
print("Using device:", device_in_use)

Using device: cpu


In [2]:
# The only preprocessing step is converting each image to a PyTorch tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Download (when missing) and open both MNIST splits from the same data directory.
mnist_root = '../mnist_data'
train_dataset = MNIST(root=mnist_root, train=True, download=True, transform=transform)
test_dataset = MNIST(root=mnist_root, train=False, download=True, transform=transform)

# Wrap each split in the project's RotatedMNISTDataset helper. The loops further
# down unpack four items per batch from the resulting loaders:
# (original_image, rotated_image, rotation_center, rotation_angle).
rotated_mnist_trainset = RotatedMNISTDataset(train_dataset)
rotated_mnist_testset = RotatedMNISTDataset(test_dataset)

# Both loaders share the same settings. Note that shuffling is enabled for the
# test loader as well, so the order of test samples differs between passes.
loader_options = dict(batch_size=128, shuffle=True)
trainloader = torch.utils.data.DataLoader(rotated_mnist_trainset, **loader_options)
testloader = torch.utils.data.DataLoader(rotated_mnist_testset, **loader_options)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../mnist_data\MNIST\raw\train-images-idx3-ubyte.gz


100.0%


Extracting ../mnist_data\MNIST\raw\train-images-idx3-ubyte.gz to ../mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ../mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz


100.0%


Extracting ../mnist_data\MNIST\raw\train-labels-idx1-ubyte.gz to ../mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ../mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz


100.0%


Extracting ../mnist_data\MNIST\raw\t10k-images-idx3-ubyte.gz to ../mnist_data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz


100.0%

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ../mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz
Extracting ../mnist_data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ../mnist_data\MNIST\raw






In [3]:
# Build a ResNet-18 via models.resnet18() with no arguments passed.
# NOTE(review): `models` comes from the star import of helpers — presumably torchvision.models; confirm.
resnet_18 = models.resnet18()

# Replace the first convolution so the network accepts 1-channel input
# (3x3 kernel, stride 1, padding 1, 64 output channels, no bias).
resnet_18.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=False)

# Disable the stem max-pooling by swapping it for an identity op.
resnet_18.maxpool = nn.Identity()
# NOTE(review): the original note here only said "translational invariance" — the intended rationale is not shown; confirm.

# Replace the final layer so the network emits three values per image.
# The evaluation cells below read columns 0-1 as the rotation centre and column 2 as the rotation angle.
num_features = resnet_18.fc.in_features
resnet_18.fc = nn.Linear(num_features, 3)

# Move the modified model to the selected device (device_in_use: 'cuda' or 'cpu').
resnet_18 = resnet_18.to(device_in_use)

# Adam over all model parameters with a learning rate of 1e-4.
optimizer = torch.optim.Adam(resnet_18.parameters(), lr=0.0001)


In [6]:
# Weights handed to custom_loss as its third argument.
# NOTE(review): presumably (centre weight, angle weight) — confirm against helpers.custom_loss.
LOSS_WEIGHTS = (.15, .85)


def _run_epoch(model, loader, optimizer=None):
    """Run one full pass over `loader` and return the mean per-batch loss.

    Args:
        model: network mapping a batch of rotated images to a (batch, 3) tensor.
        loader: iterable yielding (original_image, rotated_image,
            rotation_center, rotation_angle) batches.
        optimizer: when given, the pass trains the model (train mode, backward
            pass and optimizer step per batch). When None, the pass only
            evaluates (eval mode, gradients disabled).

    Returns:
        Sum of the per-batch loss values divided by the number of batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    running_loss = 0.0

    with torch.set_grad_enabled(training):
        for _original_image, rotated_image, rotation_center, rotation_angle in loader:
            rotated_image = rotated_image.to(device_in_use)
            # Target layout mirrors the model output: centre columns first, angle last.
            ground_truth = torch.cat((rotation_center, rotation_angle), dim=1).to(device_in_use)

            if training:
                optimizer.zero_grad()

            # The output is left as (batch, 3). A bare .squeeze() would drop the
            # batch axis for a single-sample batch and no longer line up with
            # the (batch, 3) target.
            outputs = model(rotated_image)
            loss = custom_loss(outputs, ground_truth, LOSS_WEIGHTS)

            if training:
                loss.backward()
                optimizer.step()

            running_loss += loss.item()

    return running_loss / len(loader)


train_loss = []
test_loss = []

epochs=500
early_stopping=EarlyStopping(patience=7, verbose=True, mode='min')
for epoch in range(epochs):
    avg_loss_train = _run_epoch(resnet_18, trainloader, optimizer)
    train_loss.append(avg_loss_train)

    avg_loss_test = _run_epoch(resnet_18, testloader)
    test_loss.append(avg_loss_test)

    # Report before the early-stopping check so the stopping epoch's losses are printed too.
    print(f'Epoch [{epoch+1}/{epochs}], TRAIN | Loss: {avg_loss_train}, TEST | Loss: {avg_loss_test}')

    # Early stopping monitors the average test loss of the epoch.
    early_stopping(avg_loss_test)

    if early_stopping.early_stop:
        print("Early stopping")
        break

# NOTE(review): this pickles the whole module object, so loading it later needs
# the same class definitions importable. The format is kept as-is because other
# code may load this file with torch.load.
torch.save(resnet_18, 'resnet_18_lr_0.0001_custom.pth')

KeyboardInterrupt: 

In [None]:
# Put the network in inference mode.
resnet_18.eval()

# One row per test sample: ground-truth values first, then the predictions.
data = []

# No gradients are needed while collecting predictions.
with torch.no_grad():
    for original_image, rotated_image, rotation_center, rotation_angle in testloader:
        # Predict on the configured device, then bring the result back to the CPU.
        outputs = resnet_18(rotated_image.to(device_in_use)).cpu()
        predicted_np = outputs.numpy()

        # Ground truth as NumPy arrays; the angle is flattened to one value per sample.
        true_center_np = rotation_center.numpy()
        true_angle_np = rotation_angle.numpy().reshape(-1)

        # Walk the six per-sample columns in lockstep. Output columns 0-1 are
        # read as the centre and column 2 as the angle.
        per_sample_values = zip(
            true_center_np[:, 0], true_center_np[:, 1],   # ground-truth centre x, y
            true_angle_np,                                # ground-truth angle
            predicted_np[:, 0], predicted_np[:, 1],       # predicted centre x, y
            predicted_np[:, 2],                           # predicted angle
        )
        data.extend(list(sample_values) for sample_values in per_sample_values)

# Build the results table from the accumulated rows.
columns = ['Center_X_True', 'Center_Y_True', 'Angle_True', 'Center_X_Pred', 'Center_Y_Pred', 'Angle_Pred']
df_1 = pd.DataFrame(data, columns=columns)

df_1

In [None]:
def angle_error(row):
    """Return the smallest absolute angular difference, in degrees, for one row.

    The difference between ``row['Angle_True']`` and ``row['Angle_Pred']`` is
    reduced modulo 360 and then folded into the range [0, 180], so the result
    is the shorter way around the circle regardless of the sign or magnitude
    of either angle.

    Args:
        row: mapping-like object (e.g. a DataFrame row) providing the keys
            'Angle_True' and 'Angle_Pred'.

    Returns:
        A non-negative value no greater than 180.
    """
    # Reduce the raw difference to [0, 360) first; predictions are unbounded
    # regression outputs, so the raw difference can exceed a full turn.
    wrapped = np.abs(row['Angle_True'] - row['Angle_Pred']) % 360
    # Going the other way around the circle may be shorter.
    return min(wrapped, 360 - wrapped)
    

# Add the per-sample error columns to every results frame in the list
# (currently only df_1) and print the mean of each error.
for frame in [df_1]:
    frame['Error_Angle'] = frame.apply(angle_error, axis=1)

    # Euclidean distance between the true and predicted centres.
    delta_x = frame['Center_X_True'] - frame['Center_X_Pred']
    delta_y = frame['Center_Y_True'] - frame['Center_Y_Pred']
    frame['Error_Center_Distance'] = np.sqrt(delta_x**2 + delta_y**2)

    mean_angle_error = np.mean(frame['Error_Angle'])
    mean_center_error = np.mean(frame['Error_Center_Distance'])
    print(f"Average Angle Error: {mean_angle_error}")
    print(f"Average Center Error: {mean_center_error}")

In [None]:
# Open an iterator over the test loader.
dataiter = iter(testloader)

# Switch the model to evaluation mode before predicting.
resnet_18.eval()

# Pull one batch and keep only its first sample (image pair plus ground truth).
original_image, rotated_image, rotation_center, rotation_angle = next(dataiter)
original_image, rotated_image = original_image[0], rotated_image[0]
actual_angle, actual_center = rotation_angle[0], rotation_center[0]

# Re-add a leading batch axis of size one and move the sample to the model's device.
rotated_image_batched = rotated_image.unsqueeze(0).to(device_in_use)

# Predict without tracking gradients and bring the result back to the CPU.
with torch.no_grad():
    outputs = resnet_18(rotated_image_batched).cpu()

# Columns 0-1 are read as the rotation centre, column 2 as the rotation angle.
predicted_rotation_center = outputs[:, :2]
predicted_rotation_angle = outputs[:, 2]

print(predicted_rotation_angle)
print(predicted_rotation_center)

# `outputs` is already on the CPU; .item() turns the one-element angle tensor into a Python float.
predicted_rotation_center_cpu = predicted_rotation_center.cpu()
predicted_rotation_angle_cpu = predicted_rotation_angle.item()

# Convert both tensors with the project's tensor_to_cv helper for rotation and display.
original_image_cv = tensor_to_cv(original_image)
cv_image = tensor_to_cv(rotated_image)

# Predicted centre as a plain (x, y) tuple of floats.
center = tuple(predicted_rotation_center.flatten().tolist())

# Two-decimal text form of the centre, joined into a single string.
formatted_tuple = tuple(f"{value:.2f}" for value in center)
formatted_string = ', '.join(formatted_tuple)

# Rotate the transformed digit about the predicted centre by the negated predicted angle.
registered = rotate_image(cv_image, center, -predicted_rotation_angle_cpu)

# Show the original, transformed and restored digits side by side.
fig, axes = plt.subplots(1, 3, figsize=(12, 6))
panels = (
    (original_image_cv, 'Original Image'),
    (cv_image, 'Transformed Digit'),
    (registered, 'Restored Digit'),
)
for ax, (panel_image, panel_title) in zip(axes, panels):
    ax.imshow(panel_image, cmap='gray')
    ax.set_title(panel_title)
    ax.axis('off')

plt.show()

