### Version 1

In [None]:

import torch
from transformers import AutoModelForImageClassification

class WrappedModel(torch.nn.Module):
    """Adapter exposing only the logits tensor of a classifier.

    The wrapped model's forward returns an output object with a ``.logits``
    attribute; ``torch.jit.trace`` is given a plain tensor instead, so this
    module unwraps it.

    Args:
        base_model: the classifier to wrap; kept as ``self.base`` (this name
            also prefixes the parameter names in the traced module).
    """

    def __init__(self, base_model):
        super(WrappedModel, self).__init__()
        self.base = base_model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return only its logits tensor."""
        outputs = self.base(x)
        return outputs.logits

# Fine-tuned checkpoint to export.
model_dir = "./mobilenetv2-indianfood"

# Load the classifier and put it behind the logits-only wrapper, in eval mode
# so the traced graph records inference behaviour.
base_model = AutoModelForImageClassification.from_pretrained(model_dir)
wrapped_model = WrappedModel(base_model)
wrapped_model.eval()

# One random example whose only job is to drive the trace; the recorded graph
# is traced with shape (1, 3, 224, 224).
dummy = torch.randn(1, 3, 224, 224)
traced = torch.jit.trace(wrapped_model, dummy)
traced.save("indianfood_mobilenetv2.pt")

print("✅ TorchScript model exported successfully → indianfood_mobilenetv2.pt")


✅ TorchScript model exported successfully → indianfood_mobilenetv2.pt


### Version 2 — re-export the 20-class model

In [None]:
import torch
from transformers import AutoModelForImageClassification

# Re-export the fine-tuned checkpoint, this time verifying that the classifier
# really produces EXPECTED_NUM_LABELS logits before overwriting the artifact.
# NOTE: uses `WrappedModel` defined in the "Version 1" cell — run that first.
model_dir = "./mobilenetv2-indianfood"
EXPECTED_NUM_LABELS = 20

# Load the checkpoint exactly as saved. `ignore_mismatched_sizes=True` is
# deliberately not passed: without a `num_labels=` override it cannot resize
# anything, and if config.json and the stored classifier weights disagree it
# would silently replace the trained head with randomly initialised weights.
model = AutoModelForImageClassification.from_pretrained(model_dir)

# Make sure it has 20 labels
print("num_labels =", model.config.num_labels)

# Assigning config.num_labels / config.id2label after loading does NOT resize
# the classifier layer (and placeholder names would clobber the real labels),
# so measure the width of the actual logits instead of trusting the config.
wrapped_model = WrappedModel(model).eval()
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    num_outputs = wrapped_model(dummy).shape[-1]

if model.config.num_labels != EXPECTED_NUM_LABELS or num_outputs != EXPECTED_NUM_LABELS:
    raise ValueError(
        f"Expected {EXPECTED_NUM_LABELS} classes, but config has "
        f"{model.config.num_labels} and the classifier outputs {num_outputs}; "
        "retrain/re-save the checkpoint with the correct head before exporting."
    )

# Trace and overwrite the TorchScript file written by Version 1.
traced = torch.jit.trace(wrapped_model, dummy)
traced.save("indianfood_mobilenetv2.pt")
print(f"✅ Re-exported TorchScript with {num_outputs} outputs")


num_labels = 20
✅ Re-exported TorchScript with 20 outputs


### Test

In [12]:
import torch

# Smoke-test the exported artifact on its own (no transformers needed): load
# the TorchScript file and check the logits shape for a single input.
scripted_model = torch.jit.load("indianfood_mobilenetv2.pt", map_location="cpu")
scripted_model.eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # inference only; no autograd graph needed
    out = scripted_model(x)
print(out.shape)

# Fail loudly instead of relying on someone reading the printed shape.
assert out.shape == (1, 20), f"expected logits of shape (1, 20), got {tuple(out.shape)}"

torch.Size([1, 20])


In [3]:
from transformers import AutoModelForImageClassification

# Write the class names from the checkpoint config, one per line, so that
# line i of labels.txt is `id2label[i]` for i in 0..num_labels-1.
model_dir = "./mobilenetv2-indianfood"
model = AutoModelForImageClassification.from_pretrained(model_dir)

labels = [model.config.id2label[i] for i in range(model.config.num_labels)]

# A name containing a line break would shift every following index in the
# one-per-line file, so refuse to write a corrupt mapping.
bad_labels = [name for name in labels if "\n" in name or "\r" in name]
if bad_labels:
    raise ValueError(f"labels contain line breaks: {bad_labels!r}")

# Explicit UTF-8 so non-ASCII names don't depend on the OS default encoding.
with open("labels.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(labels))

print("✅ Saved labels.txt from model config")


✅ Saved labels.txt from model config
