In [None]:
!pip install torch torchvision pillow pandas scikit-learn plotly

Library

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
import torchvision.models as models
from PIL import Image
import torch.nn as nn
from torchvision.datasets import ImageFolder
import os
from torch.utils.data import DataLoader
import random
import plotly.express as px
import pandas as pd
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
import torch.nn.functional as F
from torchvision.models import resnet18

In [None]:
#Unzip the dataset here
from google.colab import drive
drive.mount('/content/drive')


!unzip /content/drive/MyDrive/50States10K.zip -d /content/data/

Data preparation

In [5]:
def geo_loss(outputs,label,long,lat):
  device = outputs.device

  long = long.to(device)
  lat = lat.to(device)

  long1 = torch.deg2rad(long[outputs])
  long2 = torch.deg2rad(long[label])
  lat1 = torch.deg2rad(lat[outputs])
  lat2 = torch.deg2rad(lat[label])

  a = torch.sin((lat1-lat2)/2)**2 + torch.cos(lat1)*torch.cos(lat2)*torch.sin((long1-long2)/2)**2

  return torch.atan2(torch.sqrt(a),1-torch.sqrt(a))

def find_k_nearest_states(state_longs, state_lats,k=3):
    N = len(state_lats)
    nearest = []

    for i in range(N):

        others = torch.tensor([j for j in range(N)])

        dists = geo_loss(
            torch.full((len(others),), i),
            others,
            state_longs,
            state_lats
        )
        topk = torch.topk(dists, k+1, largest=False).indices
        nearest.append(others[topk])


    return nearest

In [6]:
class AveragedGroupDataset(torch.utils.data.Dataset):

    def __init__(self, root, transform=None,num_class=50):
        self.base_dataset = ImageFolder(root=root)
        self.angle = ["0","90","180","270"]
        self.transform = transform


        self.groups = dict()
        for idx, (path, label) in enumerate(self.base_dataset.samples):

            prefix = path.rsplit('_',1)[0]
            if prefix:
                key = (label, prefix)
                if key not in self.groups:
                  flag = True
                  for theta in self.angle:
                    filename = f"{prefix}_{theta}.jpg"
                    if not os.path.exists(filename):
                      flag = False
                      break
                  if flag:
                    self.groups[key] = idx
        self.groups_key = list(self.groups.keys())


    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        label, prefix = self.groups_key[idx]

        images = []
        for theta in self.angle:
            img_path = f"{prefix}_{theta}.jpg"
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            images.append(img)

        return images, label


class SplitedGroupDataset(torch.utils.data.Dataset):
    def __init__(self, base_dataset, groups_key,transform=None):
        self.base_dataset = base_dataset
        self.angle = ["0","90","180","270"]
        self.transform = transform
        self.groups_key = groups_key


    def __len__(self):
        return len(self.groups_key)


    def __getitem__(self, idx):
        label, prefix = self.groups_key[idx]

        images = []
        for theta in self.angle:
            img_path = f"{prefix}_{theta}.jpg"
            img = Image.open(img_path).convert('RGB')
            if self.transform:
                img = self.transform(img)
            images.append(img)

        return images, label

In [7]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

dataset = AveragedGroupDataset('/content/data', transform=transform)
split = dataset.groups_key
random.shuffle(split)
#this n control the size of the dataset we used for the model.
n = len(dataset)//5
split = split[:n]
#this n control the proportion of test and training set
n = 4*n//5
train = split[:n]
test = split[n:]
trainset = SplitedGroupDataset(dataset.base_dataset,train,transform)
testset = SplitedGroupDataset(dataset.base_dataset,test,transform)
dataloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True,num_workers=4)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=True,num_workers=4)

Model setup and training


In [8]:


class MedIntResNet(nn.Module):
    def __init__(self, num_classes=50, num_views=4):
        super(MedIntResNet, self).__init__()
        base_resnet = resnet18(pretrained=True)


        self.shared_conv = nn.Sequential(*list(base_resnet.children())[:-2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))


        self.feature_dim = 512


        self.fc = nn.Linear(self.feature_dim * num_views, num_classes)

    def forward(self, inputs):
        feats = []
        for view in inputs:
            x = self.shared_conv(view)
            x = self.avgpool(x)
            x = x.view(x.size(0), -1)
            feats.append(x)

        combined = torch.cat(feats, dim=1)
        out = self.fc(combined)
        return out

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
long = [
    -86.7911, -152.4044, -111.4312, -92.3731, -119.6816, -105.3111, -72.7554, -75.5071, -81.6868, -83.6431,
    -157.4983, -114.4788, -88.9861, -86.2583, -93.2105, -96.7265, -84.6701, -91.8678, -69.3819, -76.8021,
    -71.5301, -84.5361, -93.9002, -89.6787, -92.2884, -110.4544, -98.2681, -117.0554, -71.5639, -74.521,
    -106.2485, -74.9481, -79.8064, -99.784, -82.7649, -96.9289, -122.0709, -77.2098, -71.5118, -80.945,
    -99.4388, -86.6923, -97.5635, -111.8624, -72.7107, -78.17, -121.4905, -80.9545, -89.6165, -107.3025
]
long = torch.tensor(long)
long.to(device)

lat = [
    32.8067, 61.3707, 33.7298, 34.9697, 36.1162, 39.0598, 41.5978, 39.3185, 27.7663, 33.0406,
    21.0943, 44.2405, 40.3495, 39.8494, 42.0115, 38.5266, 37.6681, 31.1695, 44.6939, 39.0639,
    42.2302, 43.3266, 45.6945, 32.7416, 38.4561, 46.9219, 41.1254, 38.3135, 43.4525, 40.2989,
    34.8405, 42.1657, 35.6301, 47.5289, 40.3888, 35.5653, 44.572, 40.5908, 41.6809, 33.8569,
    44.2998, 35.7478, 31.0545, 40.15, 44.0459, 37.7693, 47.4009, 38.4912, 44.2685, 42.7559
]

lat = torch.tensor(lat)
lat.to(device)
k = 3
nearest = find_k_nearest_states(lat, long, k=3)
nearest = torch.stack(nearest).to(device)


num_classes = 50

Testing

In [None]:





model = MedIntResNet(num_classes=num_classes)

model = model.to(device)


optimizer = optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 50
w = 0.5

for epoch in range(num_epochs):
    model.train()

    for inputs, labels in dataloader:
        inputs = [v.to(device) for v in inputs]
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        pred_class = outputs.argmax(dim=1)
        loss = (1-w)*F.cross_entropy(outputs, labels) + w*geo_loss(pred_class,labels,long,lat).mean()
        loss.backward()
        optimizer.step()


In [27]:


correct = 0
total = 0

all_preds= []
all_labels = []

model.eval()
with torch.no_grad():
    for inputs, labels in testloader:
        inputs = [v.to(device) for v in inputs]
        labels = labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        nearest_k = nearest[labels]
        in_top_k = (nearest_k == predicted.unsqueeze(1)).any(dim=1)
        top_k_accuracy = in_top_k.float().mean().item()
        all_preds.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

accuracy = correct / total
print(f"Test Accuracy: {accuracy * 100:.2f}%")
print(f"Top {k} Nearest States Accuracy: {top_k_accuracy * 100:.2f}%")

precision = precision_score(all_labels, all_preds, average='macro')
print("Precision:", precision)


recall = recall_score(all_labels, all_preds, average='macro')
f1 = f1_score(all_labels, all_preds, average='macro')
print("Recall:", recall)
print("F-1 score:", f1)

Test Accuracy: 42.18%
Top 3 Nearest States Accuracy: 75.00%
Precision: 0.4139674774969325
Recall: 0.41770981207340463
F-1 score: 0.41089014927622464


Visualization and application

In [20]:
def load_instance_images(base_path, instance_id):
    angle = ['0', '90', '180', '270']
    imgs = []
    for view in angle:
        path = os.path.join(base_path, f'{instance_id}_{view}.jpg')
        img = Image.open(path).convert('RGB')
        img = transform(img)
        imgs.append(img)
    return imgs

In [26]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

us_states = [
    'AL', 'AK', 'AZ', 'AR', 'CA', 'CO', 'CT', 'DE', 'FL', 'GA',
    'HI', 'ID', 'IL', 'IN', 'IA', 'KS', 'KY', 'LA', 'ME', 'MD',
    'MA', 'MI', 'MN', 'MS', 'MO', 'MT', 'NE', 'NV', 'NH', 'NJ',
    'NM', 'NY', 'NC', 'ND', 'OH', 'OK', 'OR', 'PA', 'RI', 'SC',
    'SD', 'TN', 'TX', 'UT', 'VT', 'VA', 'WA', 'WV', 'WI', 'WY'
]

data = {
    'state': us_states,

}

df = pd.DataFrame(data)

#change the test_root to the name of the folder with the testing images
test_root = "/content/new_images"
model.eval()
#change the id to "IMG1" to switch the testing image
instance_id = 'IMG2'
###############################################
imgs = load_instance_images(test_root, instance_id)
imgs = [img.unsqueeze(0).to(device) for img in imgs]
with torch.no_grad():
    output = model(imgs)
    probs = torch.softmax(output, dim=1)
    pred = us_states[probs.argmax()]
    confidence = probs.max(dim=1).values.item()

    probs = probs.squeeze().tolist()
    df['percentage'] = [p*100 for p in probs]
    fig = px.choropleth(
    df,
    locations='state',
    locationmode="USA-states",
    color='percentage',
    scope="usa",
    color_continuous_scale='Blues',
    labels={'percentage': 'Percentage'}
    )

    fig.update_layout(title_text=f'These images might come from {pred} with confidence: {confidence:.2f}')
    fig.show()



Model loding


In [24]:
#only to save the model, not need for testing
torch.save(model.state_dict(), 'model_weights.pth')


In [25]:
#run this if want to load the model
model = MedIntResNet(num_classes=num_classes)
model.load_state_dict(torch.load('model_weights.pth'))
model.to(device)
model.eval()


The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.


Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.



MedIntResNet(
  (shared_conv): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_st