In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import torch
from torchvision import transforms
from time import sleep
from PIL import Image

import utility_modules.move_ctype  as cici
from utility_modules.capture import capture_mode

In [2]:
class MousePositionCNN(nn.Module):
    """Convolutional regressor that maps an RGB image to two values (x, y).

    Architecture: three ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages (16, 32, 64 channels) followed by three fully connected layers
    (512 -> 128 -> 2). No activation is applied to the final layer.

    Args:
        input_size: ``(height, width)`` of the images fed to ``forward``.
            The default ``(360, 640)`` yields the original ``64 * 45 * 80``
            flattened feature size, so existing checkpoints keep loading.

    Raises:
        ValueError: if either dimension is smaller than 8 pixels, which would
            leave nothing after the three 2x poolings.
    """

    def __init__(self, input_size=(360, 640)):
        super().__init__()
        height, width = input_size
        if height < 8 or width < 8:
            raise ValueError(
                f"input_size must be at least 8x8 to survive three 2x poolings, got {input_size}"
            )

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # The padded 3x3 convolutions keep the spatial size; each of the three
        # poolings halves it (rounding down), i.e. an overall floor division by 8.
        self.flat_features = 64 * (height // 8) * (width // 8)
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 2)  # Output layer with 2 values (x, y)

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor for a ``(batch, 3, H, W)`` input."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch axis. The previous
        # ``view(-1, 64 * 45 * 80)`` would silently fold a wrongly sized image
        # into extra batch rows; this way fc1 raises a shape error instead.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

class ClassificationCNN(nn.Module):
    """Convolutional classifier that maps an RGB image to two class scores.

    Architecture: three ``Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2)``
    stages (16, 32, 64 channels) followed by three fully connected layers
    (512 -> 128 -> 2). ``forward`` returns the raw scores of the last layer;
    any sigmoid/softmax is left to the caller.

    Args:
        input_size: ``(height, width)`` of the images fed to ``forward``.
            The default ``(360, 640)`` yields the original ``64 * 45 * 80``
            flattened feature size, so existing checkpoints keep loading.

    Raises:
        ValueError: if either dimension is smaller than 8 pixels, which would
            leave nothing after the three 2x poolings.
    """

    def __init__(self, input_size=(360, 640)):
        super().__init__()
        height, width = input_size
        if height < 8 or width < 8:
            raise ValueError(
                f"input_size must be at least 8x8 to survive three 2x poolings, got {input_size}"
            )

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

        # The padded 3x3 convolutions keep the spatial size; each of the three
        # poolings halves it (rounding down), i.e. an overall floor division by 8.
        self.flat_features = 64 * (height // 8) * (width // 8)
        self.fc1 = nn.Linear(self.flat_features, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, 2)  # Output layer with 2 classes

    def forward(self, x):
        """Return a ``(batch, 2)`` tensor of class scores for a ``(batch, 3, H, W)`` input."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten everything except the batch axis. The previous
        # ``view(-1, 64 * 45 * 80)`` would silently fold a wrongly sized image
        # into extra batch rows; this way fc1 raises a shape error instead.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [3]:
# Choose the inference device up front so the checkpoints can be mapped onto it
# while loading (a checkpoint saved from CUDA otherwise fails on a CPU-only box).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate the models.
# MA_model is later passed to roonie() as the attacker and ME_model as the explorer.
class_model = ClassificationCNN()
MA_model = MousePositionCNN()
ME_model = MousePositionCNN()

# Load the saved state dictionaries.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
class_model.load_state_dict(torch.load('class_torch.pth', map_location=device))
MA_model.load_state_dict(torch.load('MA_torch.pth', map_location=device))
ME_model.load_state_dict(torch.load('ME_torch.pth', map_location=device))

# Set the models to evaluation mode
class_model.eval()
MA_model.eval()
ME_model.eval()

class_model.to(device)
MA_model.to(device)
ME_model.to(device)
print('done')

# Preprocessing applied to every captured frame: convert to a float tensor,
# resize to the 360x640 input the networks were built for, and rescale each
# channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((360, 640)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

done


In [8]:
def roonie(model_classifier, model_explorer, model_attacker,
           iterations=50, screen_size=(1920, 1080),
           output_dir="predicted_data", delay=0.2):
    """Run the capture -> classify -> move-mouse control loop.

    Each iteration grabs a frame with ``capture_mode``, resizes it to 640x360,
    asks ``model_classifier`` which mode to use (index 0 = attack,
    index 1 = explore), then asks the matching position model for an (x, y)
    offset. The offset is shifted by the screen centre, handed to
    ``cici.move_cursor_steps`` and the frame is saved under ``output_dir``
    as ``{attack|explore}_{x}_{y}.jpg``.

    When the last four modes were all "attack", the current one is forced to
    "explore". Whenever the mode differs from the previous iteration (or the
    last four modes are identical) the other mode's mouse button is released
    before acting.

    Args:
        model_classifier: module returning two scores per image.
        model_explorer: module returning (x, y) used in explore mode.
        model_attacker: module returning (x, y) used in attack mode.
        iterations: number of loop iterations (default 50).
        screen_size: ``(width, height)`` of the captured screen region; its
            centre is added to the predicted offsets (default 1920x1080).
        output_dir: directory for the saved frames; created if missing.
        delay: seconds to sleep after each iteration.

    Relies on the notebook globals ``transform``, ``device``, ``cici`` and
    ``capture_mode``. Errors inside one iteration are reported and the loop
    continues (best effort).
    """
    import os  # local import so this cell does not depend on editing the import cell

    ATTACK, EXPLORE = 0, 1
    HISTORY_LEN = 4

    screen_w, screen_h = screen_size
    center_x, center_y = screen_w // 2, screen_h // 2

    # cv2.imwrite only returns False when the target directory is missing,
    # so make sure it exists before the loop starts.
    os.makedirs(output_dir, exist_ok=True)

    recent_modes = [3]  # sentinel: matches neither mode, so the first frame counts as a switch
    try:
        for n in range(iterations):
            try:
                img = capture_mode('desired', (0, 0, screen_w, screen_h))
                resized = cv2.resize(img, (640, 360))

                tbp = Image.fromarray(cv2.cvtColor(resized, cv2.COLOR_BGR2RGB))
                tbp = transform(tbp).unsqueeze(0).to(device)

                with torch.no_grad():
                    predictions = torch.sigmoid(model_classifier(tbp)).tolist()[0]
                print(predictions)

                # Remember the previous mode BEFORE recording the new one; comparing
                # after the append (as before) made the switch test always false.
                previous_mode = recent_modes[-1]
                max_index = predictions.index(max(predictions))
                recent_modes.append(max_index)
                if len(recent_modes) > HISTORY_LEN:
                    recent_modes.pop(0)
                window_full = len(recent_modes) == HISTORY_LEN

                # Four attacks in a row: force one exploration step.
                if window_full and all(mode == ATTACK for mode in recent_modes):
                    print('before', max_index)
                    max_index = EXPLORE
                    recent_modes[-1] = max_index
                    print('after', max_index)

                mode_changed = previous_mode != max_index
                window_uniform = window_full and all(mode == max_index for mode in recent_modes)
                if mode_changed or window_uniform:
                    # Release the button used by the other mode.
                    if max_index == ATTACK:
                        cici.release_left_button()
                    elif max_index == EXPLORE:
                        cici.release_right_button()

                if max_index == ATTACK:
                    print('attacking')
                    with torch.no_grad():
                        predictions = model_attacker(tbp).tolist()[0]
                        print(predictions)
                    # Predictions are offsets from the screen centre.
                    x = int(predictions[0]) + center_x
                    y = int(predictions[1]) + center_y
                    print(x, y)

                    cici.move_cursor_steps(x, y)
                    cici.press_right_button()

                elif max_index == EXPLORE:
                    print('exploring')
                    with torch.no_grad():
                        predictions = model_explorer(tbp).tolist()[0]
                        print(predictions)
                    x = int(predictions[0]) + center_x
                    y = int(predictions[1]) + center_y
                    print(x, y)

                    cici.press_left_button()
                    cici.move_cursor_steps(x, y)

                else:
                    # A classifier with more than two outputs has no matching action.
                    continue

                label = 'attack' if max_index == ATTACK else 'explore'
                frame_path = os.path.join(output_dir, f"{label}_{x}_{y}.jpg")
                if not cv2.imwrite(frame_path, resized):
                    print(f"could not write {frame_path}")
                sleep(delay)

            except Exception as e:
                # Best effort: report the failure and try again on the next frame.
                print(f"iteration {n} failed: {type(e).__name__}: {e}")
    finally:
        # Do not leave a mouse button pressed once the loop ends or is interrupted.
        cici.release_left_button()
        cici.release_right_button()


print('done')


done


In [36]:
# Short pause before the loop takes over the mouse.
sleep(2)
# ME_model drives explore mode, MA_model drives attack mode.
roonie(
    model_classifier=class_model,
    model_explorer=ME_model,
    model_attacker=MA_model,
)

[0.09193383902311325, 0.9201951026916504]
exploring
[-263.1145935058594, 54.7701530456543]
697 594
[0.33316436409950256, 0.709123432636261]
exploring
[-319.0459289550781, -114.21417236328125]
641 426
[0.9987242817878723, 0.0015492589445784688]
attacking
[-349.9613342285156, 76.44205474853516]
611 616
[1.7853904864750803e-05, 0.9999842643737793]
exploring
[73.30320739746094, -171.98214721679688]
1033 369
[0.9823943972587585, 0.01948145218193531]
attacking
[183.47647094726562, -39.135555267333984]
1143 501
[0.25843653082847595, 0.741837203502655]
exploring
[49.834224700927734, -265.4919128417969]
1009 275
[0.9982958436012268, 0.001492162118665874]
attacking
[53.07361602783203, 204.6667938232422]
1013 744
[0.9520118832588196, 0.05221624672412872]
attacking
[-286.7071838378906, 72.73279571533203]
674 612
[0.9999892711639404, 1.1010115485987626e-05]
attacking
[-183.397216796875, 64.385986328125]
777 604
[0.9991917014122009, 0.0008562697330489755]
before 0
after 1
exploring
[-446.58322143554