In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!cp /content/drive/MyDrive/DoAnCV/eff_dataset.zip /content/dataset.zip
!unzip '/content/dataset.zip' -d '/content/dataset'

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/dataset/eff_dataset/train/class_8/149564_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149629_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149781_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149782_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149786_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149913_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149916_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149917_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149919_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149920_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149922_bbox0.jpg  
  inflating: /content/dataset/eff_dataset/train/class_8/149923_bbox0.jpg  
  inflating: /content/dataset/eff_d

In [4]:
!pip install pyngrok

Collecting pyngrok
  Downloading pyngrok-7.2.9-py3-none-any.whl.metadata (9.3 kB)
Downloading pyngrok-7.2.9-py3-none-any.whl (25 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.9


In [14]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, models
from PIL import Image
import requests
from io import BytesIO
from flask import Flask, request, jsonify, send_file
from pyngrok import ngrok, conf
import os
import numpy as np

app = Flask(__name__)

NUM_CLASSES = 13
IMG_SIZE = 224
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "/content/drive/MyDrive/DoAnCV/best_cnn_resnet50.pth"
TRAIN_DIR = "/content/dataset/eff_dataset/train"

test_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def load_model(model_path):
    model = models.resnet50(pretrained=False)
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, NUM_CLASSES)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model = model.to(DEVICE)
    model.eval()
    return model

def predict_image(image, model, class_names):
    image = test_transforms(image)
    image = image.unsqueeze(0)
    image = image.to(DEVICE)

    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.softmax(outputs, dim=1)[0]
        _, predicted = torch.max(outputs, 1)
        predicted_class = class_names[predicted.item()]
        confidence = probabilities[predicted.item()].item() * 100

    return predicted_class, confidence, probabilities.cpu().numpy()

# class_names = sorted(os.listdir(TRAIN_DIR)) if os.path.exists(TRAIN_DIR) else [f"class{i}" for i in range(1, NUM_CLASSES + 1)]
# class_names = sorted(os.listdir(TRAIN_DIR))
class_names = ['class_0', 'class_1', 'class_10', 'class_11', 'class_12', 'class_2', 'class_3', 'class_4', 'class_5', 'class_6', 'class_7', 'class_8', 'class_9']

model = load_model(MODEL_PATH)

@app.route('/predict', methods=['POST'])
def predict():
    try:
        if 'file' in request.files:
            file = request.files['file']
            image = Image.open(file.stream).convert('RGB')
        elif 'url' in request.form:
            url = request.form['url']
            response = requests.get(url)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            return jsonify({"error": "No file or URL provided"}), 400

        predicted_class, confidence, probabilities = predict_image(image, model, class_names)

        class_vi = [
            'áo tay ngắn',
            'áo tay dài',
            'áo khoác tay ngắn',
            'áo khoác tay dài',
            'áo ghi-lê, áo ba lỗ, áo vest',
            'áo hai dây',
            'quần short',
            'quần dài',
            'váy ngắn',
            'đầm tay ngắn',
            'đầm tay dài',
            'đầm sát nách',
            'đầm hai dây'
        ]

        class_map = {f'class_{i}': class_vi[i] for i in range(len(class_vi))}

        predicted_label_vi = class_map.get(predicted_class, 'Không xác định')


        result = {
            "predicted_class": predicted_label_vi,
            "confidence": f"{confidence:.2f}",
            "probabilities": {class_names[i]: f"{prob*100:.2f}%" for i, prob in enumerate(probabilities)}
        }
        return jsonify(result)
    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/')
@app.route('/index.html')
def serve_html():
    return send_file("index.html")

HTML_CONTENT = """
<!DOCTYPE html>
<html>
<head>
    <title>CNN Image Prediction Demo</title>
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://unpkg.com/react@18.2.0/umd/react.production.min.js"></script>
    <script src="https://unpkg.com/react-dom@18.2.0/umd/react-dom.production.min.js"></script>
    <script src="https://unpkg.com/@babel/standalone@7.25.7/babel.min.js"></script>
    <script src="https://cdn.tailwindcss.com"></script>
    <style>
        .spinner {
            border: 4px solid rgba(0, 0, 0, 0.1);
            border-left-color: #3b82f6;
            border-radius: 50%;
            width: 24px;
            height: 24px;
            animation: spin 1s linear infinite;
            display: inline-block;
        }
        @keyframes spin {
            to { transform: rotate(360deg); }
        }
    </style>
</head>
<body>
    <div id="root"></div>
    <script type="text/babel">
        // Error Boundary Component
        class ErrorBoundary extends React.Component {
            state = { error: null };
            static getDerivedStateFromError(error) {
                return { error: error.message };
            }
            render() {
                if (this.state.error) {
                    return (
                        <div className="text-red-500 text-center p-4">
                            Something went wrong: {this.state.error}
                        </div>
                    );
                }
                return this.props.children;
            }
        }

        function App() {
            const [imageUrl, setImageUrl] = React.useState('');
            const [imageFile, setImageFile] = React.useState(null);
            const [result, setResult] = React.useState(null);
            const [error, setError] = React.useState(null);
            const [imagePreview, setImagePreview] = React.useState(null);
            const [isLoading, setIsLoading] = React.useState(false);

            const handleUrlChange = (e) => {
                setImageUrl(e.target.value);
                setImageFile(null);
                setImagePreview(e.target.value);
                setResult(null);
                setError(null);
            };

            const handleFileChange = (e) => {
                const file = e.target.files[0];
                if (file) {
                    setImageFile(file);
                    setImageUrl('');
                    setImagePreview(URL.createObjectURL(file));
                    setResult(null);
                    setError(null);
                }
            };

            const handlePredict = async () => {
                if (!imageFile && !imageUrl) {
                    setError('Please provide an image file or URL');
                    return;
                }
                setIsLoading(true);
                const formData = new FormData();
                if (imageFile) {
                    formData.append('file', imageFile);
                } else if (imageUrl) {
                    formData.append('url', imageUrl);
                }

                try {
                    const response = await fetch('/predict', {
                        method: 'POST',
                        body: formData
                    });
                    const data = await response.json();
                    if (data.error) {
                        setError(data.error);
                        setResult(null);
                    } else {
                        setResult(data);
                        setError(null);
                    }
                } catch (err) {
                    setError('Failed to connect to the server');
                    setResult(null);
                } finally {
                    setIsLoading(false);
                }
            };

            return (
                <div className="min-h-screen bg-gray-100 flex items-center justify-center p-4 sm:p-6">
                    <div className="w-full max-w-lg bg-white rounded-lg shadow-lg p-6 sm:p-8">
                        <h1 className="text-2xl sm:text-3xl font-bold mb-6 text-center text-gray-800">
                            CNN Image Prediction Demo
                        </h1>
                        <div className="space-y-4">
                            <input
                                type="text"
                                placeholder="Paste image URL here"
                                value={imageUrl}
                                onChange={handleUrlChange}
                                className="w-full p-3 border rounded-lg focus:ring-2 focus:ring-blue-500 focus:border-blue-500 transition"
                            />
                            <input
                                type="file"
                                accept="image/*"
                                onChange={handleFileChange}
                                className="w-full text-sm text-gray-500 file:mr-4 file:py-2 file:px-4 file:rounded file:border-0 file:bg-blue-50 file:text-blue-700 hover:file:bg-blue-100"
                            />
                            <button
                                onClick={handlePredict}
                                disabled={isLoading}
                                className={`w-full py-3 rounded-lg font-semibold text-white transition duration-300 ${
                                    isLoading ? 'bg-blue-400 cursor-not-allowed' : 'bg-blue-500 hover:bg-blue-600'
                                } flex items-center justify-center`}
                            >
                                {isLoading ? (
                                    <>
                                        <span className="spinner mr-2"></span>
                                        Predicting...
                                    </>
                                ) : (
                                    'Predict'
                                )}
                            </button>
                            {imagePreview && (
                                <div className="mt-4 flex justify-center">
                                    <img
                                        src={imagePreview}
                                        alt="Preview"
                                        className="max-w-full h-auto rounded-lg shadow-md max-h-64"
                                    />
                                </div>
                            )}
                            {error && (
                                <div className="text-red-500 mt-4 text-center font-medium">
                                    Error: {error}
                                </div>
                            )}
                            {result && (
                                <div className="mt-6 text-center bg-gray-50 p-4 rounded-lg">
                                    <p className="text-lg font-semibold text-gray-800">
                                        Predicted Class: {result.predicted_class}
                                    </p>
                                    <p className="text-gray-600">Confidence: {result.confidence}</p>
                                    <h3 className="mt-3 font-medium text-gray-700">Class Probabilities:</h3>
                                    <ul className="mt-2 text-sm text-gray-600 space-y-1">
                                        {Object.entries(result.probabilities).map(([cls, prob]) => (
                                            <li key={cls}>{cls}: {prob}</li>
                                        ))}
                                    </ul>
                                </div>
                            )}
                        </div>
                    </div>
                </div>
            );
        }

        const root = ReactDOM.createRoot(document.getElementById('root'));
        root.render(
            <ErrorBoundary>
                <App />
            </ErrorBoundary>
        );
    </script>
</body>
</html>
"""

with open("index.html", "w") as f:
    f.write(HTML_CONTENT)

def main():
    authtoken = "Auth here"
    conf.get_default().auth_token = authtoken

    public_url = ngrok.connect(5000, proto="http").public_url
    print(f" * ngrok tunnel available at: {public_url}")
    print(f" * Open the frontend at: {public_url}/index.html")

    app.run(host="0.0.0.0", port=5000)

if __name__ == "__main__":
    main()

 * ngrok tunnel available at: https://cd3f-34-125-114-14.ngrok-free.app
 * Open the frontend at: https://cd3f-34-125-114-14.ngrok-free.app/index.html
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.28.0.12:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:35:08] "GET / HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:35:10] "[33mGET /favicon.ico HTTP/1.1[0m" 404 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:35:13] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:35:39] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:35:49] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:36:08] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:36:14] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:36:17] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:36:25] "[36mGET / HTTP/1.1[0m" 304 -
INFO:werkzeug:127.0.0.1 - - [04/Jun/2025 14:36:29] "POST /predict HTTP/1.1" 200 -
INFO:werkzeug:127.0