In [17]:
# Standard library
import math
import os
import subprocess
from datetime import datetime, timezone
from typing import List, Tuple, Optional

# Third-party
import matplotlib.pyplot as plt
import nest_asyncio
import pandas as pd
import pyarrow.parquet as pq
import requests
import vt
from dotenv import load_dotenv
from PyPDF2 import PdfMerger
from tqdm.notebook import tqdm  # tqdm.notebook for Jupyter notebook

#to be able to run your async code in the notebook
nest_asyncio.apply()

In [18]:
mode = 'benign'  # Dataset to process: 'benign' or 'malign' (change to 'malign' to read from the malign dataset)

# Upper bound on the number of domains successfully fetched from VirusTotal in one run
# (each lookup is one API call; per the author's note the academic API quota is 20k calls per day)
batch_size = 1

## DomainAnalyzer
**Objective**: Define the `DomainAnalyzer` class that will handle domain analysis tasks.

- **Functions Included**:
    - `__init__`: Initializes the `DomainAnalyzer` with a VirusTotal API key loaded from an environment variable.
    - `__enter__` and `__exit__`: Context management methods to handle the setup and cleanup of the client.
    - `load_api_key`: Load the VirusTotal API key from the `VT_API_KEY` environment variable (read from a `.env` file if present).
    - `check_domain`: Fetch information for a specific domain.
    - `get_verdict`: Determine the verdict of the analysis based on the domain's analysis stats.
    - `is_domain_live`: Check if a domain is live by calling a bash script.
    - `extract_domain_data`: Extract necessary data from the domain result.
    - `load_previous_data`: Load previously processed domain data from a CSV file.
    - `save_data`: Save the DataFrame containing domain data to a CSV file.
    - `generate_report`: Generate a report based on the DataFrame and save it as a PDF.
    - `process_selected_domains`: Process the domains based on the mode ('malign' or 'benign') in batches.

In [19]:
class DomainAnalyzer:
    """
    Look up domains on VirusTotal, probe whether they are live, and persist/report the results.

    Lookups go to the VirusTotal v3 REST API through ``requests`` using the API key
    found in the ``VT_API_KEY`` environment variable. Results are accumulated in
    ``previous_data_<mode>.csv`` and the already-queried domains in
    ``processed_domains_<mode>.txt`` so repeated runs continue where the last one stopped.

    The class can be used as a context manager; entering and exiting hold no resources.
    """

    # Column layout shared by the CSV store, the result frames and the PDF report.
    COLUMNS = ("Domain", "Verdict", "Detection Ratio", "Detection Timestamp",
               "Harmless", "Malicious", "Suspicious", "Live Status")

    # Seconds to wait for VirusTotal before giving up on a single lookup.
    REQUEST_TIMEOUT = 30

    def __init__(self):
        """
        Load the API key and prepare the HTTP headers sent with every lookup.

        Raises:
            ValueError: if ``VT_API_KEY`` is not set.
        """
        self.api_key = self.load_api_key()
        self.headers = {
            "x-apikey": self.api_key,
            "Accept": "application/json"
        }

    @staticmethod
    def load_api_key():
        """
        Return the VirusTotal API key from the ``VT_API_KEY`` environment variable.

        A ``.env`` file is loaded first (via ``load_dotenv``) so the key may live there.

        Raises:
            ValueError: if the variable is missing or empty.
        """
        load_dotenv()  # Load environment variables from .env file
        api_key = os.getenv('VT_API_KEY')  # Get the API key from the environment variable

        # An empty value is as unusable as a missing one.
        if not api_key:
            raise ValueError("API key is not set. Please set the VT_API_KEY environment variable.")

        return api_key

    def __enter__(self):
        """
        Enter the runtime context for the DomainAnalyzer.
        """
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the runtime context for the DomainAnalyzer.

        Nothing needs releasing; returning None lets any exception propagate.
        """
        return None

    def check_domain(self, domain: str) -> Optional[dict]:
        """
        Fetch the VirusTotal report for ``domain``.

        Returns:
            The decoded JSON body on HTTP 200, otherwise None (the error is printed).
        """
        url = f"https://www.virustotal.com/api/v3/domains/{domain}"
        try:
            # Without a timeout a stalled connection would block the whole batch.
            response = requests.get(url, headers=self.headers, timeout=self.REQUEST_TIMEOUT)
        except requests.RequestException as e:
            print(f"Error: Unable to fetch information for domain {domain}. {e}")
            return None

        if response.status_code == 200:
            try:
                return response.json()
            except ValueError as e:  # body was not valid JSON
                print(f"Error: Unable to fetch information for domain {domain}. {e}")
                return None

        print(f"Error: Unable to fetch information for domain {domain}. {response.text}")
        return None

    def get_verdict(self, analysis_stats: dict) -> str:
        """
        Determine the verdict of the analysis.

        Returns "Malign" when at least one engine reports the domain as malicious or
        more than one reports it as suspicious; otherwise "Benign".
        """
        if analysis_stats.get('malicious', 0) > 0 or analysis_stats.get('suspicious', 0) > 1:
            return "Malign"
        return "Benign"

    def is_domain_live(self, domain: str) -> str:
        """
        Check if a domain is live by calling a bash script.

        Runs ``./livetest.sh <domain>`` relative to the current working directory.

        Returns:
            "Alive" if the script prints exactly ``1``, "Dead" for any other output,
            and "Unknown" if the script could not be run.
        """
        try:
            # Running the bash script and capturing the output
            result = subprocess.run(['./livetest.sh', domain], capture_output=True, text=True)
        except Exception as e:
            print(f"Error: Unable to check if domain {domain} is live. {e}")
            return "Unknown"

        return "Alive" if result.stdout.strip() == '1' else "Dead"

    def extract_domain_data(self, domain: str, result: dict) -> Optional[Tuple]:
        """
        Extract necessary data from the domain result.

        Args:
            domain: The domain the report belongs to.
            result: Decoded VirusTotal response as returned by ``check_domain``.

        Returns:
            A tuple ordered like ``COLUMNS`` (domain, verdict, detection ratio,
            formatted UTC timestamp, harmless, malicious, suspicious, live status),
            or None when the report lacks the analysis stats or a usable analysis date.
        """
        try:
            attributes = result['data']['attributes']
            analysis_stats = attributes['last_analysis_stats']
        except (KeyError, TypeError):
            print(f"Error: Could not extract analysis stats for domain {domain}")
            return None
        if not isinstance(analysis_stats, dict):
            print(f"Error: Could not extract analysis stats for domain {domain}")
            return None

        # Read every counter with a default so a missing key cannot raise.
        harmless = analysis_stats.get('harmless', 0)
        malicious = analysis_stats.get('malicious', 0)
        suspicious = analysis_stats.get('suspicious', 0)

        verdict = self.get_verdict(analysis_stats)
        detection_ratio = f"{malicious}/{malicious + harmless}"

        try:
            detection_timestamp = attributes['last_analysis_date']
        except (KeyError, TypeError):
            print(f"Error: Could not extract last analysis date for domain {domain}")
            return None

        try:
            # Convert from UNIX epoch format to a UTC datetime object
            dt_obj = datetime.fromtimestamp(detection_timestamp, tz=timezone.utc)
        except (ValueError, TypeError, OverflowError, OSError):
            print(f"Error: Could not convert last analysis date for domain {domain}")
            return None
        # Format to desired string format
        formatted_timestamp = dt_obj.strftime('%Y-%m-%d %H:%M:%S')

        domain_status = self.is_domain_live(domain)
        return (domain, verdict, detection_ratio, formatted_timestamp,
                harmless, malicious, suspicious, domain_status)

    @staticmethod
    def _previous_data_filename(dataset_mode: Optional[str] = None) -> str:
        """Return the CSV store path for ``dataset_mode`` (notebook-level ``mode`` when None)."""
        selected = mode if dataset_mode is None else dataset_mode
        return f'previous_data_{selected}.csv'

    def load_previous_data(self, dataset_mode: Optional[str] = None) -> pd.DataFrame:
        """
        Load previously processed domain data from a CSV file.

        Args:
            dataset_mode: Which store to read; defaults to the notebook-level ``mode``.

        Returns:
            The stored frame, or an empty frame with the standard columns if no file exists.
        """
        previous_data_filename = self._previous_data_filename(dataset_mode)
        if os.path.exists(previous_data_filename):
            return pd.read_csv(previous_data_filename)
        return pd.DataFrame(columns=list(self.COLUMNS))

    def save_data(self, df: pd.DataFrame, dataset_mode: Optional[str] = None) -> None:
        """
        Save the DataFrame containing domain data to a CSV file.

        Args:
            df: Frame to persist (written without the index).
            dataset_mode: Which store to write; defaults to the notebook-level ``mode``.
        """
        df.to_csv(self._previous_data_filename(dataset_mode), index=False)

    def _merge_into_previous_data(self, df: pd.DataFrame, dataset_mode: Optional[str] = None) -> None:
        """Fold ``df`` into the CSV store; rows already stored win over duplicates by Domain."""
        previous_data_filename = self._previous_data_filename(dataset_mode)
        if os.path.exists(previous_data_filename):
            old_df = pd.read_csv(previous_data_filename)
            merged_df = pd.concat([old_df, df]).drop_duplicates(subset=['Domain']).reset_index(drop=True)
        else:
            merged_df = df
        merged_df.to_csv(previous_data_filename, index=False)

    @staticmethod
    def _build_summary_rows(df: pd.DataFrame) -> pd.DataFrame:
        """Build the two 'Benign count' / 'Malign count' footer rows for the report."""
        total_count = len(df)
        rows = []
        for label, verdict in (('Benign count', 'Benign'), ('Malign count', 'Malign')):
            count = int((df['Verdict'] == verdict).sum())
            cells = [''] * len(df.columns)
            cells[1] = label                       # shown in the second column
            cells[2] = f'{count}/{total_count}'    # shown in the third column
            rows.append(cells)
        return pd.DataFrame(rows, columns=df.columns)

    def _render_page(self, page_df: pd.DataFrame, col_widths: list, summary_row_count: int,
                     page_filename: str) -> None:
        """Render ``page_df`` as a styled table and save it as a one-page PDF."""
        # Adjust the height of the figure based on the number of rows on this page
        fig_height = len(page_df) * 0.15
        fig, ax = plt.subplots(figsize=(12, fig_height))
        try:
            ax.axis('off')  # Hide axes
            fig.tight_layout(pad=0.1)

            tab = pd.plotting.table(ax, page_df, loc='upper center', colWidths=col_widths,
                                    cellLoc='center', rowLoc='center')
            tab.auto_set_font_size(True)
            tab.set_fontsize(8)
            tab.scale(1.2, 1.2)

            # Table row 0 is the header, so data row i sits at table row i + 1.
            first_summary_row = len(page_df) - summary_row_count + 1

            # Style adjustments (bold headers, colors based on verdict, hiding index)
            for (row, col), cell in tab.get_celld().items():
                text = cell.get_text()
                label = text.get_text()
                if row == 0 or col == -1:
                    text.set_weight('bold')
                if label in ('Malign', 'Dead'):
                    text.set_color('red')
                elif label in ('Benign', 'Alive'):
                    text.set_color('green')
                if col == -1:
                    cell.set_visible(False)
                if summary_row_count and row >= first_summary_row:
                    # Special styling for the benign and malign count rows
                    text.set_weight('bold')
                    cell.set_facecolor('lightgrey')

            # Save the table as a PDF
            fig.savefig(page_filename, bbox_inches='tight', dpi=300)
        finally:
            plt.close(fig)

    def generate_report(self, df: pd.DataFrame, output_filename: str, rows_per_page: int = 500,
                        dataset_mode: Optional[str] = None) -> None:
        """
        Generate a report based on the DataFrame and save it as a PDF.

        The frame is split into pages of ``rows_per_page`` rows; each page is rendered
        to a temporary ``temp_page_<n>.pdf`` in the working directory and the pages are
        merged into ``output_filename``. Two summary rows (benign / malign counts) are
        appended to the last page. As a side effect ``df`` is also merged into the
        ``previous_data_<mode>.csv`` store.

        Args:
            df: Result frame with the standard columns.
            output_filename: Path of the merged PDF; its directory is created if missing.
            rows_per_page: Maximum number of data rows per page (must be >= 1).
            dataset_mode: Which CSV store to update; defaults to the notebook-level ``mode``.

        Raises:
            ValueError: if ``df`` is empty or ``rows_per_page`` is smaller than 1.
        """
        if rows_per_page < 1:
            raise ValueError("rows_per_page must be at least 1.")
        if df.empty:
            raise ValueError("Cannot generate a report from an empty DataFrame.")

        # Save the merged data for future use
        self._merge_into_previous_data(df, dataset_mode)

        summary_rows = self._build_summary_rows(df)
        num_pages = math.ceil(len(df) / rows_per_page)

        # One set of widths for all pages so the columns line up across the report.
        longest_domain = df["Domain"].astype(str).str.len().max()
        col_widths = [longest_domain * 0.20 * 0.02 if column == "Domain"
                      else 0.12 if column == "Detection Timestamp"
                      else 0.10 for column in df.columns]

        output_dir = os.path.dirname(output_filename)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        pdf_merger = PdfMerger()
        page_filenames = []
        try:
            for page in range(num_pages):
                start_row = page * rows_per_page
                page_df = df.iloc[start_row:start_row + rows_per_page]

                summary_row_count = 0
                if page == num_pages - 1:
                    # The counts close the report, so they go on the final page only.
                    page_df = pd.concat([page_df, summary_rows], ignore_index=True)
                    summary_row_count = len(summary_rows)

                page_filename = f"temp_page_{page+1}.pdf"
                page_filenames.append(page_filename)  # registered first so cleanup sees partial files
                self._render_page(page_df, col_widths, summary_row_count, page_filename)

                # Add each page to the merger
                pdf_merger.append(page_filename)

            # Save the merged PDF
            with open(output_filename, 'wb') as merged_pdf:
                pdf_merger.write(merged_pdf)
        finally:
            # Release the merger's handles before deleting the temporary files
            pdf_merger.close()
            for page_filename in page_filenames:
                if os.path.exists(page_filename):
                    os.remove(page_filename)

    def process_selected_domains(self, mode: str, batch_size) -> Optional[pd.DataFrame]:
        """
        Process the domains based on the mode ('malign' or 'benign') in batches.

        Domains already listed in ``processed_domains_<mode>.txt`` are skipped. The run
        stops once ``batch_size`` domains have been fetched successfully; failed lookups
        are neither counted nor marked as processed, so they are retried on the next run.
        New rows are merged with ``previous_data_<mode>.csv`` and written back.

        Args:
            mode: Dataset selector, 'malign' or 'benign'.
            batch_size: Maximum number of domains to fetch in this run.

        Returns:
            The merged result frame (sorted by verdict and live status, rows with
            missing values dropped), or None if ``mode`` is invalid.
        """
        paths = {
            'malign': '../floor/misp_2307.parquet',
            'benign': '../floor/benign_cesnet_union_2307.parquet'
        }

        # Check if the mode is valid
        if mode not in paths:
            print(f"Invalid mode '{mode}'. Please use 'malign' or 'benign'.")
            return None

        # Read only the domain-name column of the selected Parquet file
        table = pq.read_table(paths[mode], columns=['domain_name'])
        domain_names = table.column('domain_name').to_pandas()

        # Load the processed domains
        processed_domains_file = f"processed_domains_{mode}.txt"
        if os.path.exists(processed_domains_file):
            with open(processed_domains_file, 'r') as file:
                processed_domains = file.read().splitlines()
        else:
            processed_domains = []
        # The list keeps file order; the set gives O(1) membership tests.
        already_processed = set(processed_domains)

        data = []
        processed_in_this_run = 0
        for domain in tqdm(domain_names, desc='Processing domains', unit='domain'):
            if processed_in_this_run >= batch_size:  # Check if the batch size is reached
                break
            if domain in already_processed:
                continue

            result = self.check_domain(domain)
            if not result:
                continue

            row = self.extract_domain_data(domain, result)
            if row is not None:  # unusable reports must not end up as rows
                data.append(row)
            # The API call was spent either way, so do not query this domain again.
            processed_domains.append(domain)
            already_processed.add(domain)
            processed_in_this_run += 1

        # Save the updated processed domains
        with open(processed_domains_file, 'w') as file:
            file.write('\n'.join(processed_domains))

        # Create a DataFrame from the newly processed data
        new_df = pd.DataFrame(data, columns=list(self.COLUMNS))
        old_df = self.load_previous_data(mode)

        if old_df.empty:
            merged_df = new_df
        elif new_df.empty:
            merged_df = old_df
        else:
            merged_df = pd.concat([old_df, new_df]).drop_duplicates(subset=['Domain']).reset_index(drop=True)

        merged_df = merged_df.sort_values(by=['Verdict', 'Live Status'], ascending=[False, False]).dropna()

        # Save the merged data to the store that belongs to this run's mode
        self.save_data(merged_df, mode)

        # Print how many domains were processed in total, also include percentages
        total_domains = len(domain_names)
        percentage = len(merged_df) / total_domains * 100 if total_domains else 0.0
        print(f"Total number of domains processed: {len(merged_df)} out of {total_domains} ({percentage:.2f}%)")
        return merged_df


**Objective**: Utilize the `DomainAnalyzer` class to process and analyze domains.

- **Steps**:
    1. Instantiate the `DomainAnalyzer` class.
    2. Use the `process_selected_domains` method to process domains based on the specified mode and batch size.
    3. If domains are processed successfully, generate and save a report using the `generate_report` method.

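**Note**: Before running this cell, make sure the `VT_API_KEY` environment variable is set (for example in a `.env` file), the `./livetest.sh` script is present and executable in the working directory, the input Parquet files exist under `../floor/`, and the output directory `../false_positives/VT/` exists and is writable.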


In [20]:
# Run the analysis for the configured `mode` and, when there are results, render the PDF report.
with DomainAnalyzer() as analyzer:  # the analyzer doubles as a context manager
    # Fetch up to `batch_size` new domains and merge them with earlier results
    df = analyzer.process_selected_domains(mode, batch_size)

    # A None result (invalid mode) or an empty frame means there is nothing to report
    if df is None or df.empty:
        print(f"No domains processed for mode '{mode}'. No report generated.")
    else:
        analyzer.generate_report(df, f'../false_positives/VT/{mode}_VT_check.pdf')
        print(f'Report saved to ../false_positives/VT/{mode}_VT_check.pdf')

Processing domains:   0%|          | 0/486250 [00:00<?, ?domain/s]

Total number of domains processed: 1 out of 486250 (0.00%)
Report saved to ../false_positives/VT/benign_VT_check.pdf
